License: CC BY 4.0
arXiv:2211.00086v2 [cs.LG] 03 Jan 2024

Disentangled (Un)Controllable Features

Jacob E. Kooi
Quantitative Data Analytics, Vrije Universiteit Amsterdam
Amsterdam, Netherlands

Mark Hoogendoorn
Quantitative Data Analytics, Vrije Universiteit Amsterdam
Amsterdam, Netherlands

Vincent Francois-Lavet
Quantitative Data Analytics, Vrije Universiteit Amsterdam
Amsterdam, Netherlands

Abstract

In the context of MDPs with high-dimensional states, downstream tasks are predominantly applied on a compressed, low-dimensional representation of the original input space. A variety of learning objectives have therefore been used to attain useful representations. However, these representations usually lack interpretability of the different features. We present a novel approach that is able to disentangle latent features into a controllable and an uncontrollable partition. We illustrate that the resulting partitioned representations are easily interpretable on three types of environments and show that, in a distribution of procedurally generated maze environments, it is feasible to interpretably employ a planning algorithm in the isolated controllable latent partition.

I Introduction

Learning from high-dimensional data remains a challenging task. Particularly for reinforcement learning (RL), the complexity and high dimensionality of the Markov Decision Process (MDP) [1] states often lead to complex or intractable solutions. To facilitate learning from high-dimensional input data, an encoder architecture can be used to compress the inputs into a lower-dimensional latent representation. To this end, a plethora of work has successfully focused on discovering a compressed encoded representation that accommodates the underlying features for the task at hand [2, 3, 4, 5, 6, 7, 8].

The resulting low-dimensional representations, however, seldom contain specific disentangled features, which leads to disorganized latent information. This means that the individual latent states can represent the information from the state in an arbitrary way. The result is a representation with poor interpretability, as the latent states cannot be connected to certain attributes of the original observation space (e.g., the x-y coordinates of the agent). Prior work in structuring a latent representation has shown notions and use of interpretability in MDP representations [9]. When expanding this notion of interpretability to be compatible with RL, it has been argued that the controllable features should be an important element of a latent representation, since they generally represent what is directly influenced by the policy. In this light, [10] introduced the concept of isolating and disentangling controllable features in a low-dimensional maze environment, by means of a selectivity loss. Furthermore, [11] took an object-centric approach to isolate distinct objects in MDPs and [12] showed theoretical foundations for this isolation in a weakly-supervised controllable setting. Controllable features, however, only represent a fragment of an environment, and in many cases the uncontrollable features are of equal importance. For example, in the context of a distribution of mazes, the information about the wall structure is crucial for predicting the next controllable (agent) state following an action (see Fig. 1). We therefore hypothesize that a thorough representation should incorporate controllable and uncontrollable features, ideally in a disentangled, interpretable arrangement. Interpretability is crucial for future real-world deployment [13], and an additional benefit is that the separation of the controllable and uncontrollable features can be exploited in downstream algorithms such as planning.

Our contribution is an algorithm that explicitly disentangles the latent representation into a controllable and an uncontrollable latent partition. We showcase it on three types of MDP environments, each with a different class of controllable and uncontrollable elements. The disentanglement allows for a precise and visible separation of the latent features, improving interpretability and representation quality, and possibly moving towards a basis for building causal relationships between an agent and its environment. The unsupervised learning algorithm consists of both an action-conditioned and a state-only forward predictor, along with a contrastive and an adversarial loss, which isolate and disentangle the controllable versus the non-controllable features. Furthermore, we show an application of learning and planning on the human-interpretable disentangled latent representation, where the properties of disentanglement allow the planning algorithm to operate solely in the controllable partition of the latent representation.


Figure 1: Visualization in a maze environment of four random pixel observations $s\in\mathbb{R}^{48\times 48}$ (left) and the encoded observations $z=f(s;\theta_{enc})\;\forall s\in\mathcal{S}$ (right). On the right, we can see the disentanglement of the controllable latent $z^{c}\in\mathbb{R}^{2}$ on the horizontal axes, and the uncontrollable latent $z^{u}\in\mathbb{R}^{1}$ on the vertical axis. The encoder is trained on high-dimensional tuples $(s_{t},a_{t},r_{t},s_{t+1})$, sampled from a replay buffer $\mathcal{B}$, gathered from random trajectories in the four maze environments shown on the left. All possible states in all four mazes are encoded and plotted with the transition prediction for each possible action, revealing a clear disentanglement between the controllable latents (agent x-y position) and the uncontrollable latent (wall architecture). Note that all samples are taken from the same buffer, filled with samples from all four mazes.

II Related Work

General Representation Learning

Many works have focused on converting high-dimensional inputs to a compact, abstract latent representation. Learning this representation can make use of auxiliary, unsupervised tasks in addition to the pure RL objectives [3]. One way to ensure a meaningful latent space is to implement architectures that require a pixel reconstruction loss such as a variational [14, 15] or a deterministic [6] autoencoder. Other approaches combined reward reconstruction with latent prediction [16], pixel reconstruction with planning [17, 18] or used latent predictive losses without pixel reconstruction [5, 7].

Representing controllable features

In representation learning for RL, a focus on controllable features can be beneficial as these features are strongly influenced by the policy [10]. This can be done using generative methods [19], but is most commonly pursued using an auxiliary inverse-prediction loss; predicting the action that was taken in the MDP [2]. The work in [20, 21] builds a latent representation with an emphasis on the controllable features of an environment with inverse-prediction losses, and uses these features to guide exploratory behavior. Furthermore, [22] and concurrent work by [23] employ multi-step inverse prediction to successfully encompass controllable features in their representation. However, these works have not expressed a focus on also retaining the uncontrollable features in their representation, which is a key aspect in our work.

Partitioning a latent representation

Sharing similarity in terms of the separation of the latent representation, [24] disentangle the latent representation in the domain adaptation setting into a task-relevant and a context partition, by means of adversarial predictions with gradient reversals and cyclic reconstruction. [25] use a reconstruction-based adversarial architecture that divides their latent representation into reward-relevant and irrelevant features. Related work by [26] further divides the latent representation of Dreamer [17], using action-conditioned and state-only forward predictors, into controllable, uncontrollable and their respective reward relevant and irrelevant features. As compared to [26], who focus on distraction-efficient RL, we purely focus on the representational learning aspect of these predictors, and show notions of separation in low-dimensional, structured representations of MDPs, leaning towards enhanced interpretability. Furthermore, we use an adversarial loss to enforce disentanglement between $z^{c}$ and $z^{u}$, and apply a contrastive loss instead of pixel reconstruction to avoid representation collapse due to latent forward prediction.

Interpretable representations in MDPs

More closely related to our research is the work by [10], which connects individual latent dimensions to independently controllable states in a maze using a reconstruction loss and a selectivity loss. The work by [9] visualizes the representation of an agent and its transitions in a maze environment, but does not disentangle the agent state into its controllable and uncontrollable parts, which limits the interpretability analysis and does not allow simplifications during planning. The work by [11] uses an object-oriented approach to isolate different (controllable) features, using graph neural networks (GNNs) and a contrastive forward prediction loss, but does not discriminate between controllable and uncontrollable features. Further work in this direction by [12] focuses on theoretical foundations for an encoder to structurally represent a distinct controllable object. We aim to progress the aforementioned lines of research by using a representation learning architecture that disentangles an MDP’s latent representation into interpretable controllable and uncontrollable features. Finally, we show that having separate partitions of controllable and uncontrollable features can be exploited in a planning algorithm. Exploitations like these are done in combination with prior knowledge of a certain MDP, as in [27].


Figure 2: Overview of the disentangling architecture, with dashed lines representing gradient propagation and green rectangles representing parameterized prediction functions. An observation $s_{t}$ is encoded into a latent representation consisting of two parts, $z^{c}_{t}$ and $z^{u}_{t}$, which represent controllable and uncontrollable features respectively. These separated representations are then independently used to make action-conditioned, state-only and adversarial predictions in order to provide gradients to the encoder that disentangle the latent representation $z_{t}$ into controllable ($z^{c}_{t}$) and uncontrollable ($z^{u}_{t}$) partitions.

III Preliminaries

We consider an agent acting within an environment, where the environment is modeled as a discrete Markov Decision Process (MDP) defined as a tuple $(\mathcal{S},\mathcal{A},T,R,\gamma)$. Here, $\mathcal{S}$ is the state space, $\mathcal{A}$ is the action space, $T:\mathcal{S}\times\mathcal{A}\rightarrow\mathcal{S}$ is the environment’s transition function, $R:\mathcal{S}\times\mathcal{A}\rightarrow\mathcal{R}$ is the environment’s reward mapping and $\gamma$ is the discount factor. We consider the setting where we have access to a replay buffer ($\mathcal{B}$) of visited states $s_{t}\in\mathcal{S}$ that were followed by actions $a_{t}\in\mathcal{A}$ and resulted in the rewards $r_{t}\in\mathcal{R}$ and the next states $s_{t+1}$. One entry in $\mathcal{B}$ contains a tuple of past experience $(s_{t},a_{t},r_{t},s_{t+1})$.
The agent’s goal is to learn a policy $\pi:\mathcal{S}\rightarrow\mathcal{A}$ that maximizes the expectation of the discounted return $V^{\pi}(s)=\mathbb{E}_{\tau}[\sum_{t=0}^{\infty}\gamma^{t}R(s_{t},a_{t})\mid s_{t}=s]$, where $\tau$ is a trajectory following the policy $\pi$.
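As a small worked example of the discounted return inside $V^{\pi}$, the sketch below sums $\gamma^{t}r_{t}$ over a single finite list of rewards. It is only an illustration: the function name, the reward values and the choice of $\gamma$ are assumptions, and a real estimate of $V^{\pi}(s)$ would average such returns over many trajectories.

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma**t * rewards[t] for one finite trajectory."""
    total = 0.0
    # Accumulate from the last reward backwards: G_t = r_t + gamma * G_{t+1}.
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


# Hypothetical trajectory: a reward of 1 arrives only at the third step.
print(discounted_return([0.0, 0.0, 1.0], gamma=0.5))  # prints 0.25
```

With $\gamma=0.5$ the reward at $t=2$ is weighted by $0.5^{2}=0.25$, which is the value printed.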

Furthermore, we examine the setting where a high-dimensional state ($s_{t}\in\mathbb{R}^{v}$) is compressed into a lower-dimensional latent state $z_{t}\in\mathcal{Z}=\mathbb{R}^{w}$, where $\mathcal{Z}$ represents the latent space with $w\leq v$. This is done by means of a neural network encoding $f:\mathcal{S}\rightarrow\mathcal{Z}$, where $f$ represents the encoder.

IV Algorithm

We aim for an interpretable and disentangled representation of the controllable and uncontrollable latent features. We define controllable features as the characteristics of the MDP that are predominantly affected by any action $a\in\mathcal{A}$, such as the position of the agent in the context of a maze environment. The uncontrollable features are those attributes that are not or only marginally affected by the actions. We show that the proposed disentanglement is possible by designing losses and gradient propagation through two separate parts of the latent representation. Specifically, to assign controllable information to the controllable latent partition, the gradient from an action-conditioned forward predictor is propagated through it. To assign uncontrollable information to the uncontrollable latent partition, the gradient from a state-only forward predictor is propagated through it. The remaining details will be provided in the rest of this Section.

We consider environments with high-dimensional states, represented as pixel inputs. These pixel inputs are subsequently encoded into a latent representation $z_{t}=(z^{c},z^{u})\in\mathcal{Z}=\mathbb{R}^{n_{c}+n_{u}}$, with the superscripts $c$ and $u$ representing the controllable and uncontrollable features, and $n_{c}$ and $n_{u}$ representing their respective dimensions. The compression into a latent representation $\mathcal{S}\rightarrow\mathcal{Z}$ is done by means of a convolutional encoder, parameterized by a set of learnable parameters $\theta_{enc}$ according to:

$$z_{t}=(z^{c}_{t},z^{u}_{t})=f(s_{t};\theta_{enc}). \tag{1}$$

An overview of the proposed algorithm is illustrated in Fig. 2 and the details are provided hereafter. In this section, all losses and transitions are given under the assumption of a continuous abstract representation and a deterministic transition function. The algorithm could be adapted by replacing the losses related to the internal transitions with generative approaches (in the context of continuous and stochastic transitions) or a log-likelihood loss (in the context of stochastic but discrete representations).
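The partitioned encoder output of Eq. (1) can be pictured with the following minimal sketch. It is not the paper's convolutional encoder: a random linear map on plain Python lists stands in for $f(\cdot;\theta_{enc})$, and the helper name `make_encoder`, the weight range and the example observation are assumptions made for illustration. Only the split of the output into $z^{c}$ and $z^{u}$ mirrors the text, with the sizes $n_{c}=2$ and $n_{u}=1$ taken from Fig. 1.

```python
import random

N_C, N_U = 2, 1          # latent sizes reported for the maze in Fig. 1
INPUT_DIM = 48 * 48      # flattened 48x48 pixel observation


def make_encoder(input_dim, n_c, n_u, seed=0):
    """Build a stand-in encoder: one random linear layer (hypothetical)."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(input_dim)]
               for _ in range(n_c + n_u)]

    def encode(state):
        z = [sum(w * x for w, x in zip(row, state)) for row in weights]
        # First n_c entries form z^c, the remaining n_u entries form z^u.
        return z[:n_c], z[n_c:]

    return encode


encode = make_encoder(INPUT_DIM, N_C, N_U)
observation = [0.0] * INPUT_DIM
observation[100] = 1.0   # a single lit pixel, purely for illustration
z_c, z_u = encode(observation)
print(len(z_c), len(z_u))  # prints 2 1
```

Keeping the two partitions as separate return values makes it straightforward to route different losses, and therefore different gradients, to each of them.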

IV-A Controllable Features

To isolate controllable features in the latent representation, $z^{c}_{t}$ is used to make an action-conditioned forward prediction in latent space. In the context of a continuous latent space and deterministic transitions, $z^{c}$ is updated using a mean squared error (MSE) forward prediction loss $\mathcal{L}_{c}=\big|\hat{z}^{c}_{t+1}-z^{c}_{t+1}\big|^{2}$, where $\hat{z}^{c}_{t+1}$ is the action-conditioned residual forward prediction of the parameterized function $T_{c}(z,a;\theta_{c}):\mathcal{Z}\times\mathcal{A}\rightarrow\mathcal{Z}$:

$$\hat{z}^{c}_{t+1}=T_{c}(z_{t},a_{t};\theta_{c})+z^{c}_{t} \tag{2}$$

and the prediction target $z^{c}_{t+1}$ is part of the encoder output $f(s_{t+1};\theta_{enc})$. Note that the full latent state $z_{t}$ is necessary in order to predict $\hat{z}^{c}_{t+1}$ (e.g. the uncontrollable features could represent a wall or other static structure that is necessary for the prediction of the controllable features). Furthermore, the uncontrollable latent partition input $z^{u}_{t}$ is accompanied by a stop gradient to discourage the presence of controllable features in $z^{u}$. When minimizing $\mathcal{L}_{c}$, both the encoder ($\theta_{enc}$) as well as the predictor ($\theta_{c}$) are updated, which allows shaping the representation $z^{c}$ as well as learning the internal dynamics.
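To make the residual prediction of Eq. (2) and the stop gradient on $z^{u}_{t}$ concrete, the sketch below computes $\mathcal{L}_{c}$ and its gradients by hand for a linear stand-in predictor. Everything beyond the formulas in the text is an assumption: $T_{c}$ is a single weight matrix `W` acting on the concatenation of $z^{c}_{t}$, $z^{u}_{t}$ and a one-hot action, the target is treated as a fixed number, and the numeric values are invented.

```python
def controllable_loss(z_c, z_u, action_onehot, target_c, W):
    """Return (L_c, dL_c/dz_c, dL_c/dz_u) for a linear stand-in T_c.

    W has len(z_c) rows and len(z_c) + len(z_u) + len(action_onehot) columns.
    """
    inputs = z_c + z_u + action_onehot           # full latent state plus action
    delta = [sum(w * x for w, x in zip(row, inputs)) for row in W]
    pred = [d + z for d, z in zip(delta, z_c)]   # Eq. (2): residual prediction
    err = [p - t for p, t in zip(pred, target_c)]
    loss = sum(e * e for e in err)

    n_c, n_u = len(z_c), len(z_u)
    # Gradient w.r.t. z_c flows through both T_c and the residual connection.
    grad_z_c = [2.0 * (sum(W[i][j] * err[i] for i in range(n_c)) + err[j])
                for j in range(n_c)]
    # Stop gradient: z_u is read by the predictor but receives no gradient.
    grad_z_u = [0.0] * n_u
    return loss, grad_z_c, grad_z_u


# Hypothetical maze-like numbers: 2-D position, 1-D wall code, 4 actions.
W = [[0.0] * 7 for _ in range(2)]
W[0][3] = 1.0                                    # action 0 shifts x by +1
loss, grad_z_c, grad_z_u = controllable_loss(
    z_c=[1.0, 2.0], z_u=[0.5], action_onehot=[1.0, 0.0, 0.0, 0.0],
    target_c=[2.0, 3.0], W=W)
print(loss, grad_z_c, grad_z_u)
```

In an automatic-differentiation framework the same effect would typically come from detaching $z^{u}_{t}$ before passing it to the predictor; the zeroed `grad_z_u` here plays that role explicitly.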

IV-B Uncontrollable Features

To express uncontrollable features in the latent space, $z^{u}_{t}$ is used to make a state-only (not conditioned on the action $a_{t}$) forward prediction in latent space. This enforces uncontrollable features within the uncontrollable latent partition $z^{u}$, since features that are action-dependent cannot be accurately predicted with the preceding state only. Following a residual prediction, $z^{u}$ is then updated using an MSE forward prediction loss $\mathcal{L}_{u}=\big|\hat{z}^{u}_{t+1}-z^{u}_{t+1}\big|^{2}$, with $\hat{z}^{u}_{t+1}$ defined as:

$$\hat{z}^{u}_{t+1}=T_{u}(z^{u}_{t};\theta_{u})+z^{u}_{t} \tag{3}$$

and $T_{u}(z^{u};\theta_{u}):\mathcal{Z}\rightarrow\mathcal{Z}$ representing the parameterized prediction function. The target $z^{u}_{t+1}$ is part of the output of the encoder $f(s_{t+1};\theta_{enc})$. When minimizing $\mathcal{L}_{u}$, both $\theta_{enc}$ and $\theta_{u}$ are updated.
In this way the loss $\mathcal{L}_{u}$ drives the latent representation $z^{u}$, which is conditioned on $\theta_{enc}$ according to $(z^{c}_{t},z^{u}_{t})=f(s_{t};\theta_{enc})$, to only represent the features of $s_{t}$ that are not conditioned on the action $a_{t}$.
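The argument that action-dependent features cannot be predicted from the preceding state alone can be checked on a toy case. In the sketch below, a constant residual shift stands in for $T_{u}(z^{u};\theta_{u})$; this simplification, the helper names and the two tiny data sets are assumptions for illustration. Both transitions in each set start from the same latent value, so any state-only predictor must output the same prediction for both.

```python
def state_only_loss(z_u_t, z_u_next, shift):
    """Squared error of the residual prediction z_u_t + shift, as in Eq. (3)."""
    pred = [z + d for z, d in zip(z_u_t, shift)]
    return sum((p - t) ** 2 for p, t in zip(pred, z_u_next))


def best_constant_shift(transitions):
    """MSE-optimal constant residual for a list of 1-D (current, next) pairs."""
    shifts = [nxt[0] - cur[0] for cur, nxt in transitions]
    return [sum(shifts) / len(shifts)]


def mean_loss(transitions):
    shift = best_constant_shift(transitions)
    losses = [state_only_loss(cur, nxt, shift) for cur, nxt in transitions]
    return sum(losses) / len(losses)


# A static feature (e.g. a wall code) never changes between steps.
static_feature = [([3.0], [3.0]), ([3.0], [3.0])]
# An action-driven feature moves +1 or -1 depending on the unseen action.
action_driven_feature = [([3.0], [4.0]), ([3.0], [2.0])]
print(mean_loss(static_feature), mean_loss(action_driven_feature))  # prints 0.0 1.0
```

The static feature is predicted perfectly, while the action-driven feature leaves an irreducible error, which is the pressure that pushes action-dependent information out of $z^{u}$.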

IV-C Avoiding Predictive Representation Collapse

Minimizing a forward prediction loss in latent space $\mathcal{Z}$ is prone to collapse [9, 16], due to the convergence of $\mathcal{L}_{c}$ and $\mathcal{L}_{u}$ when $f(s_{t};\theta_{enc})$ is a constant $\forall\, s_{t}\in\mathcal{S}$. To avoid representation collapse when using forward predictors, a contrastive loss is used to enforce sufficient diversity in the latent representation:

$$\mathcal{L}_{H_{1}}=\exp\big(-C_{d}\big\|z_{t}-\bar{z}_{t}\big\|_{2}\big) \tag{4}$$

where $C_{d}$ represents a constant hyperparameter and $\bar{z}_{t}$ is a ‘negative’ batch of latent states $z_{t}$, which is obtained by shifting each position of latent states in the batch by a random number between 0 and the batch size. In the random maze environment, an additional contrastive loss is added to further diversify the controllable representation:

$$\mathcal{L}_{H_{2}}=\exp\big(-C_{d}\big\|z^{c}_{t}-\bar{z}^{c}_{t}\big\|_{2}\big) \tag{5}$$

where $z^{c}_{t}$ is obtained from randomly sampled trajectories. This additional regularizer proved necessary to avoid collapse of $z^{c}$ when moving to a near-infinite number of possible mazes. More information on this subject can be found in Appendix A-D. The resulting contrastive loss $\mathcal{L}_{H}$ for the random maze environment then consists of $0.5\,\mathcal{L}_{H_{1}}+0.5\,\mathcal{L}_{H_{2}}$. The total loss used to update the encoder’s parameters now consists of $\mathcal{L}_{enc}=\mathcal{L}_{c}+\mathcal{L}_{u}+\mathcal{L}_{H}$.
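The contrastive terms of Eqs. (4) and (5) and the combined encoder loss can be sketched as follows. Several details are assumptions rather than statements from the text: the negative batch is formed by rolling the whole batch by one random offset between 1 and the batch size minus one (one possible reading of the shifting procedure), the loss is averaged over the batch, and the value of $C_{d}$ and the example latents are invented.

```python
import math
import random


def contrastive_loss(batch, c_d, rng):
    """Batch mean of exp(-C_d * ||z - z_bar||_2), as in Eq. (4) or Eq. (5)."""
    # Assumption: roll the batch by a random non-zero offset so that no
    # latent state is paired with itself.
    shift = rng.randrange(1, len(batch))
    negatives = batch[shift:] + batch[:shift]
    terms = [math.exp(-c_d * math.dist(z, z_neg))
             for z, z_neg in zip(batch, negatives)]
    return sum(terms) / len(terms)


def encoder_loss(loss_c, loss_u, loss_h1, loss_h2):
    """L_enc = L_c + L_u + L_H with L_H = 0.5 * L_H1 + 0.5 * L_H2."""
    return loss_c + loss_u + 0.5 * loss_h1 + 0.5 * loss_h2


rng = random.Random(0)
collapsed = [[0.0, 0.0, 0.0]] * 4                  # every state maps to one point
spread = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]        # distinct latent states
loss_collapsed = contrastive_loss(collapsed, c_d=5.0, rng=rng)
loss_spread = contrastive_loss(spread, c_d=5.0, rng=rng)
print(loss_collapsed, loss_spread)
```

A fully collapsed batch attains the maximum value of 1, whereas well-separated latents drive the term towards 0, so adding it to $\mathcal{L}_{c}+\mathcal{L}_{u}$ penalizes the constant-encoder solution.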

IV-D Guiding Feature Disentanglement with Adversarial Loss

When using a controllable latent space $z^c \in \mathbb{R}^x$, $x \in \mathbb{N}$, where $x > g$, with $g$ representing the number of dimensions needed to portray the controllable features, some information about the uncontrollable features in the controllable latent representation might be present (see Appendix C-B). This is due to the non-enforcing nature of $\mathcal{L}_c$, as the uncontrollable features are equally predictable with or without the action. To ensure that no information about the uncontrollable features is kept in the controllable latent representation, an adversarial component is added to the architecture in Fig. 2. This is done by updating the encoder with an adversarial loss $\mathcal{L}_{adv}$ and reversing the gradient [28]. The adversarial loss is defined as

$$\mathcal{L}_{adv} = \big| \hat{z}^u_t - z^u_t \big|^2, \tag{6}$$

with $\hat{z}^u_t = T_{adv}(z^c_t; \theta_{adv})$, where $\hat{z}^u_t$ is the uncontrollable prediction of the parameterized function $T_{adv}(z^c; \theta_{adv}): \mathcal{Z} \rightarrow \mathcal{Z}$ and $z^u_t$ is the target.
Intuitively, since the parameters of $T_{adv}(z^c; \theta_{adv})$ are updated with $\mathcal{L}_{adv}$ and the parameters of $f(s; \theta_{enc})$ are updated with $-\mathcal{L}_{adv}$, the prediction function can be seen as the discriminator and the encoder as the generator [29]. The discriminator tries to give an accurate prediction of the uncontrollable latent $z^u$ given the controllable latent $z^c$, while the generator tries to counteract the discriminator by removing any uncontrollable features from the controllable representation. In our case, the predictor is a multi-layer perceptron (MLP), which means that the adversarial training on $\mathcal{L}_{adv}$ enforces that no nonlinear relation between $z^c$ and $z^u$ can be learned. We hypothesize that this is a deterministic approximation of minimizing the Mutual Information (MI) between $z^u$ and $z^c$.
When using the adversarial loss, the combined loss propagating through the encoder consists of $\mathcal{L}_{enc} = \mathcal{L}_c + \mathcal{L}_u + \mathcal{L}_H - \mathcal{L}_{adv}$. Here the minus sign in $-\mathcal{L}_{adv}$ represents a gradient reversal to the encoder. Note that the losses are not scaled, as this did not prove to be necessary for the experiments conducted.
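
The opposing update directions can be illustrated with a scalar toy, assuming a linear adversary $\hat{z}^u = a \cdot z^c$ in place of the MLP and treating $z^c$ as if it could be updated directly rather than through the encoder parameters. All names and numbers below are hypothetical, and the sketch only shows the sign of a single step, not convergence of the full adversarial game.

```python
def adv_loss(a, z_c, z_u):
    """Squared error of a linear adversary z_u_hat = a * z_c (scalar case of Eq. 6)."""
    return (a * z_c - z_u) ** 2

def adv_grads(a, z_c, z_u):
    """Analytic gradients of the loss w.r.t. the adversary weight a and the latent z_c."""
    err = a * z_c - z_u
    return 2.0 * err * z_c, 2.0 * err * a

def adversary_step(a, z_c, z_u, lr):
    """The adversary descends on L_adv."""
    grad_a, _ = adv_grads(a, z_c, z_u)
    return a - lr * grad_a

def encoder_side_step(a, z_c, z_u, lr):
    """The encoder side receives the reversed gradient, i.e. it descends on -L_adv."""
    _, grad_z = adv_grads(a, z_c, z_u)
    reversed_grad = -grad_z
    return z_c - lr * reversed_grad

a, z_c, z_u, lr = 0.5, 2.0, 3.0, 0.1
before = adv_loss(a, z_c, z_u)
after_adversary = adv_loss(adversary_step(a, z_c, z_u, lr), z_c, z_u)
after_encoder = adv_loss(a, encoder_side_step(a, z_c, z_u, lr), z_u)
```

In this toy, one adversary step lowers the prediction error while one reversed-gradient step on the latent raises it, which is the discriminator/generator tension described above.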

Algorithm 1 Disentangled (Un)Controllable Features
Initialize $\theta_{enc}$, $\theta_c$, $\theta_u$, $\theta_{adv}$
for $iteration = 1, 2, \ldots, N$ do
    Sample batch of tuples $\{s_t, a_t, s_{t+1}\}$
    Encode observations: $f(s; \theta_{enc}) = \{z^c, z^u\}$
    Predict $\hat{z}^c_{t+1} = T_c(z^c_t, z^u_t, a; \theta_c) + z^c_t$    // detach $z^u_t$
    Predict $\hat{z}^u_{t+1} = T_u(z^u_t; \theta_u) + z^u_t$
    Predict $\hat{z}^u_t = T_{adv}(z^c_t; \theta_{adv})$
    Compute losses $\mathcal{L}_c$, $\mathcal{L}_u$, $-\mathcal{L}_{adv}$, $\mathcal{L}_H$
    Update parameters $\theta_{enc}$, $\theta_c$, $\theta_u$, $\theta_{adv}$
end for
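
The prediction and loss steps of Algorithm 1 can be sketched for a single transition as follows. This is an illustration under stated assumptions, not the authors' implementation: latents are plain Python lists, the predictors are arbitrary callables, $\mathcal{L}_c$ and $\mathcal{L}_u$ are assumed to be squared errors of the same form as Eq. 6, the contrastive value `l_h` is passed in, and the assignment of $\mathcal{L}_c$ to $T_c$ and $\mathcal{L}_u$ to $T_u$ is an assumption. Detaching $z^u_t$ needs an autodiff framework and is only noted in a comment.

```python
def sq_err(pred, target):
    """Sum of squared differences between two equal-length vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))

def dcf_losses(z_c, z_u, z_c_next, z_u_next, action, T_c, T_u, T_adv, l_h):
    """Residual predictions and per-module objectives for one transition."""
    # Action-conditioned residual prediction of the controllable latent.
    # In an autodiff setting z_u would be detached here, as in Algorithm 1.
    z_c_hat_next = [d + z for d, z in zip(T_c(z_c, z_u, action), z_c)]
    # Action-free residual prediction of the uncontrollable latent.
    z_u_hat_next = [d + z for d, z in zip(T_u(z_u), z_u)]
    # Adversary predicts z_u from z_c alone.
    z_u_hat = T_adv(z_c)

    l_c = sq_err(z_c_hat_next, z_c_next)
    l_u = sq_err(z_u_hat_next, z_u_next)
    l_adv = sq_err(z_u_hat, z_u)
    return {
        "encoder": l_c + l_u + l_h - l_adv,  # minus sign = gradient reversal
        "T_c": l_c,
        "T_u": l_u,
        "T_adv": l_adv,
    }

# Hypothetical toy predictors: T_c moves z_c by the action vector,
# T_u predicts no change, T_adv always predicts zero.
objectives = dcf_losses(
    z_c=[1.0, 1.0], z_u=[3.0], z_c_next=[1.0, 2.0], z_u_next=[3.0],
    action=[0.0, 1.0],
    T_c=lambda zc, zu, a: list(a),
    T_u=lambda zu: [0.0 for _ in zu],
    T_adv=lambda zc: [0.0],
    l_h=0.5,
)
```

The encoder objective decreases when the adversary's error grows, while the adversary's own objective is that error itself.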

IV-E Downstream Tasks

By disentangling a latent representation into a controllable and an uncontrollable part, one can more readily obtain human-interpretable features. While interpretability is generally an important aspect, it is also important to test how a notion of human interpretability affects downstream performance, as it is generally desired to strike a good balance between interpretability and performance. This is examined by training an RL agent on the learned and subsequently frozen latent representation. The action $a_t$ is chosen following an $\epsilon$-greedy policy, where a random action is taken with probability $\epsilon$, and with probability $(1-\epsilon)$ the policy $\pi(z) = \operatorname*{arg\,max}_{a \in \mathcal{A}} Q(z, a; \theta)$ is evaluated, where $Q(z, a; \theta)$ is the Q-network trained by Deep Double Q-Learning (DDQN) [30, 31]. The Q-network is trained with respect to a target $Y_t$:

$$Y_t = r_t + \gamma\, Q\big(z_{t+1}, \operatorname*{arg\,max}_{a \in \mathcal{A}} Q(z_{t+1}, a; \theta); \theta^-\big). \tag{7}$$

Here, $\gamma$ represents the environment’s discount factor and $\theta^-$ the target Q-network’s parameters. The target Q-network’s parameters are updated as an exponential moving average of the original parameters $\theta$ according to $\theta^-_{k+1} = (1-\tau)\theta^-_k + \tau\theta_k$, where the subscript $k$ represents a training iteration and $\tau$ is a hyperparameter controlling the speed of the parameter update. The resulting DDQN loss is defined as $\mathcal{L}_Q = \big| Y_t - Q(z_t, a; \theta) \big|^2$. The full computation of all losses is shown in pseudocode in Algorithm 1.
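
The $\epsilon$-greedy choice, the double-Q target of Eq. 7 and the moving-average target update can be sketched with Q-values given as plain lists indexed by action. This is a minimal illustration: the function names are hypothetical, and the numeric values of `gamma`, `tau` and the Q-values are made up for the example.

```python
import random

def argmax(values):
    """Index of the largest entry (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

def epsilon_greedy(q_values, epsilon, rng=random):
    """Random action with probability epsilon, greedy action otherwise."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return argmax(q_values)

def double_q_target(reward, gamma, q_online_next, q_target_next):
    """Eq. 7: the online network selects the action, the target network evaluates it."""
    best_action = argmax(q_online_next)
    return reward + gamma * q_target_next[best_action]

def soft_update(target_params, online_params, tau):
    """theta^-_{k+1} = (1 - tau) * theta^-_k + tau * theta_k, applied element-wise."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]

# Illustrative numbers only.
y = double_q_target(reward=-0.1, gamma=0.9,
                    q_online_next=[1.0, 2.0], q_target_next=[5.0, 3.0])
new_target = soft_update([0.0, 2.0], [1.0, 2.0], tau=0.1)
greedy_action = epsilon_greedy([0.1, 0.7, 0.3], epsilon=0.0)
```

Note that the target uses the value 3.0 that the target network assigns to the online network's preferred action, not the target network's own maximum of 5.0; this decoupling is what distinguishes the double-Q target from the standard one.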

V Experiments

In this section, we showcase the disentanglement of controllable and uncontrollable features on three different environments, the complexity of which is in line with prior work on structured representations [10, 32, 9, 11, 12]: (i) a quadruple maze environment, (ii) the catcher environment and (iii) a randomly generated maze environment. The first environment yields a state space of 119 different observations, and is used to showcase the algorithm’s ability to disentangle a low-dimensional latent representation. The catcher environment examines a setting where the uncontrollable features are not static, and the random maze environment is used to showcase disentanglement in a more complex distribution of over 25 million possible environments, followed by two downstream tasks: reinforcement learning (DDQN) and a latent planning algorithm running in the controllable latent partition. The base of the encoder is derived from [33] and consists of two convolutional layers, followed by a fully connected layer for low-dimensional latent representations or an additional CNN for a higher-dimensional latent representation such as a feature map. For the full network architectures, we refer the reader to Appendix C. In all environments, the encoder $f(s; \theta_{enc})$ is trained from a buffer $\mathcal{B}$ filled with transition tuples $(s_t, a_t, r_t, s_{t+1})$ from random trajectories. Note that, in interpretability, there is generally not a specific metric to optimize for.
In order to produce interpretable representations, finding the right hyperparameters required manual (human) inspection of the plotted latent representations. An ablation of the hyperparameters used can be found in Appendices A1-A3.

V-A Quadruple Maze Environment

The maze environment consists of an agent and a selection of four distinct, handpicked wall architectures. The environment’s state is provided as pixel observations $s_t \in \mathbb{R}^{48 \times 48}$, where an action moves the agent by 6 pixels in the chosen direction (up, down, left or right) unless that direction is obstructed by a wall. We consider the context where there is no reward ($r_t = 0 \;\; \forall \;\; (s_t, a_t) \in (\mathcal{S}, \mathcal{A})$) and there is no terminal state.

We select a two-dimensional controllable representation ($z^c \in \mathbb{R}^2$) and a one-dimensional uncontrollable representation ($z^u \in \mathbb{R}^1$). The remaining hyperparameters and details can be found in Appendix B. The experiments are conducted using a buffer $\mathcal{B}$ filled with random trajectories from the four different basic maze architectures. The encoder’s parameters are updated using $\mathcal{L}_{enc}$ in Section IV-C with $\mathcal{L}_H = \mathcal{L}_{H_1}$. After 50k training iterations, a clear disentanglement between the controllable ($z^c$) and uncontrollable ($z^u$) latent representation can be seen in Fig. 1. One can observe that the encoder is updated so that the one-dimensional latent representation $z^u$ learns different values that define the type of wall architecture. A progression to this representation is provided in Appendix C-A.

(a) Without $\mathcal{L}_{adv}$

(b) With $\mathcal{L}_{adv}$

Figure 3: Visualization of the latent feature disentanglement in the catcher environment after 200k training iterations, with $z_t = f(s_t; \theta_{enc}) \in \mathbb{R}^2 + \mathbb{R}^{6 \times 6}$. In (a) and (b), the left column shows $z^c_t$, the middle column is a feature map representing $z^u_t$ and the right column is the pixel state $s_t$. The dashed lines separate observations where the ball position or the paddle position is kept fixed for illustration purposes. $z^c$ tracks the agent position while $z^u$ tracks the falling ball. In (b), note that even with a two-dimensional controllable state (only one dimension is needed, see Appendix C-B), the adversarial loss ensures that distinct ball positions have a negligible effect on $z^c$ (left column), even when the high-level features of the agent and the ball might be hard to distinguish.

V-B Catcher Environment

As opposed to the maze environment, the catcher environment encompasses uncontrollable features that are non-stationary. The ball is dropped randomly at the top of the environment and falls irrespective of the actions, while the paddle position is directly modified by the actions. The environment’s states are defined as pixel observations $s_t \in \mathbb{R}^{51 \times 51}$. At each time step, the paddle moves left or right by 3 pixels. Since we are only doing unsupervised learning, we consider the context where there is no reward ($r_t = 0 \;\; \forall \;\; (s_t, a_t) \in (\mathcal{S}, \mathcal{A})$) and an episode ends whenever the ball reaches the paddle or the bottom.

We take $z^c \in \mathbb{R}^2$ and $z^u \in \mathbb{R}^{6 \times 6}$. To test disentanglement, $z^c$ is of a higher dimension than needed, since the paddle (agent) only moves along the x-axis and would therefore require only one feature (see Appendix C-B for the simpler setting with $z^c \in \mathbb{R}^1$). To show disentanglement, the redundant dimension of $z^c$ should contain no, or only negligible, information about $z^u$. The encoder’s parameters are updated using $\mathcal{L}_{enc}$ in Section IV-D with $\mathcal{L}_H = \mathcal{L}_{H_1}$. After training the encoder for 200k iterations, a selection of state observations $s_t$ and their encoding into the latent representation $z = (z^c, z^u)$ can be seen in Fig. 3.
A clear distinction between the ball and paddle representations can be observed, with the former residing in $z^u$ and the latter in $z^c$.

V-C Random Maze Environment

The random maze environment is similar to the maze environment from Section V-A, but consists of a large distribution of randomly generated mazes with complex wall structures. The environment’s state is provided as pixel observations $s_t \in \mathbb{R}^{48 \times 48}$, where an action moves the agent by 6 pixels in the chosen direction. We consider $z^c \in \mathbb{R}^2$ and $z^u \in \mathbb{R}^{6 \times 6}$. This environment tests the generalization properties of a disentangled latent representation, as there are over 25 million possible maze architectures, corresponding to a probability of less than $4 \cdot 10^{-8}$ to sample the same maze twice. Note that because $z^c$ is 2-dimensional, results with and without adversarial loss are in practice extremely close. After 50k training iterations, the latent representation $z = (z^c, z^u)$ shows an interpretable disentanglement between the controllable and the uncontrollable features (see Fig. 4(a)). A clear distinction between the agent and the wall structure can be found inside $z^c$ and $z^u$.
Note that instead of using a single dimension to ‘describe’ the uncontrollable features $z^u$ (see Fig. 1), using a feature map for $z^u$ allows training an encoding that provides a more interpretable representation of the actual wall architecture.
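
The quoted probability follows from a one-line calculation, assuming (for illustration) that mazes are drawn uniformly and independently from the $N$ possible architectures:

```latex
P(\text{second maze} = \text{first maze}) = \frac{1}{N}
  < \frac{1}{2.5 \cdot 10^{7}} = 4 \cdot 10^{-8},
  \qquad N > 2.5 \cdot 10^{7}.
```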

(a) $\mathcal{L}_c = \mathcal{L}_c$

(b) $\mathcal{L}_c = \mathcal{L}_{inv}$

(c) $\mathcal{L}_{enc} = \mathcal{L}_Q$

Figure 4: A plot of the latent representation for all observations in a single randomly sampled maze when training with the aforementioned losses (a), substituting the action-conditioned forward-prediction loss $\mathcal{L}_c$ for an inverse-prediction loss $\mathcal{L}_{inv}$ (b) and when end-to-end updating the encoder with only the Q-loss $\mathcal{L}_Q$ from DDQN for 500k iterations (c). The left column shows the controllable latent $z^c_t \in \mathbb{R}^2$ with the current state in blue, the remaining states in red, and the predicted movement due to actions as different colored bars for each individual action. The middle column shows the uncontrollable latent $z^u_t \in \mathbb{R}^{6 \times 6}$ and the right column shows the original state $s_t \in \mathbb{R}^{48 \times 48}$. Evidently, the controllable representations in (b) and (c) lack disentanglement and interpretability. Furthermore, the representation in (c) seems to have very little structure at all, showing that a representation that is optimized without prior structural incentives will often represent a black box.

Using an Inverse Predictor

An alternative to the state-action forward prediction method used throughout the paper is the inverse (action) prediction loss. An inverse prediction loss is often referred to in previous work that focuses on controllable features [2, 20, 21]. A single-step inverse prediction is defined as:

$$\hat{a}_t = I(z^c_t, z^c_{t+1}, z^u_t; \theta_{inv}). \tag{8}$$

Here, $\hat{a}_t$ is the predicted action and $I(z^c_t, z^c_{t+1}, z^u_t; \theta_{inv}): \mathcal{Z} \rightarrow \mathcal{A}$ is the inverse prediction network. To see whether an inverse predictor can generate structured, controllable representations in the random maze environment, we replace the action-conditioned forward predictor with an inverse predictor, so that $z^c$ is no longer updated with $\mathcal{L}_c$ but with $\mathcal{L}_{inv}$ (see Appendix A-F for details on $\mathcal{L}_{inv}$).

The resulting representation can be seen in Fig. 4(b). It seems that using $\mathcal{L}_{inv}$ causes an absence of interpretable structure in the controllable latent representation $z^c_t$. Furthermore, there is a less precise disentanglement between the controllable and uncontrollable features, as differences can be observed in $z^c_t$ when encoding equal agent positions as pixel states $s_t$. In addition, an inverse predictor does not allow forward prediction in latent space, which can be used for planning as shown hereafter. It thus seems that in some environments, an inverse prediction loss might be insufficient to isolate the controllable features. Take for example the maze agent in the top-right maze of Fig. 4, where the agent can only move in the left direction. Even when using the wall information ($z^u_t$), an inverse predictor will not be able to predict the action taken when the agent does not go left. However, an action-conditioned forward predictor is able to predict the next state correctly regardless of which action was taken.
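
The dead-end argument can be made concrete with a toy grid world, sketched below under simplifying assumptions: positions are integer cells rather than pixels, the maze is a hypothetical two-cell corridor, and blocked moves leave the agent in place. The forward map (state, action) to next state is a function, whereas the inverse map (state, next state) to action is not unique when several actions are blocked.

```python
# Hypothetical action set as (dx, dy) offsets on a cell grid.
ACTIONS = {"up": (0, -1), "down": (0, 1), "left": (-1, 0), "right": (1, 0)}

def step(pos, action, free_cells):
    """Forward model: move one cell unless the target cell is a wall."""
    dx, dy = ACTIONS[action]
    nxt = (pos[0] + dx, pos[1] + dy)
    return nxt if nxt in free_cells else pos

def consistent_actions(pos, next_pos, free_cells):
    """All actions that explain the observed transition pos -> next_pos."""
    return sorted(a for a in ACTIONS if step(pos, a, free_cells) == next_pos)

# Dead end: the agent at (1, 0) can only move left.
free_cells = {(0, 0), (1, 0)}
stayed = consistent_actions((1, 0), (1, 0), free_cells)
moved = consistent_actions((1, 0), (0, 0), free_cells)
```

When the agent stays in place, three actions are equally consistent with the transition, so no inverse predictor can recover which one was taken, while `step` still returns the correct next state for each of them.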


Figure 5: Performance of different (pre)trained representations on the random maze environment, measured as a mean (full line) and standard error (shaded area) over 5 seeds. The ‘Interpretable’ setting uses an encoder pre-trained with 50k iterations to acquire a representation as in Fig. 3(a), after which the encoder is frozen and a Q-network is trained on top with DDQN for 500k iterations. The ‘Interpretable + Planning’ curve is similar to the ‘Interpretable’ setting but uses DDQN with a planning algorithm in the controllable partition of the latent space with a depth of 3. The ‘DDQN’ setting uses an encoder trained end-to-end with only DDQN for 500k iterations and the ‘Inverse Prediction’ setting is equal to the ‘Interpretable’ setting but has an encoder pre-trained with $\mathcal{L}_{inv}$ instead of $\mathcal{L}_c$.

Reinforcement Learning

In order to verify whether a human-interpretable disentangled latent encoding is informative enough for downstream tasks, we formalize the random maze environment into an MDP with rewards. The agent acquires a reward $r_t$ of -0.1 at every time step, except when it finds the key in the top right part, in which case it acquires a positive reward of 1. The episode ends whenever a positive reward is obtained or a total of 50 environment steps have been taken. For each new episode, a random wall structure is generated, and the agent starts over in the bottom left section of the maze (see Fig. 5). To see whether an interpretable disentangled latent representation is useful for RL, we compare four scenarios of (pre)training: (i) an encoder pretrained for 50k iterations to attain the representation in Fig. 3(a) and subsequently trained with DDQN for 500k iterations; (ii) an encoder identical to the aforementioned but trained with DDQN and a planning algorithm; (iii) an encoder pretrained for 50k iterations with $\mathcal{L}_{inv}$ instead of $\mathcal{L}_c$ and subsequently trained with DDQN for 500k iterations; (iv) an encoder purely trained with DDQN gradients for 500k iterations. The resulting performances are compared in Fig. 5. We find that a disentangled structured representation is suitable for downstream tasks, as it achieves comparable performance to training an encoder end-to-end with DDQN for 500k iterations. Although performance is similar, Fig. 3(c) shows that an encoder updated solely with the DDQN gradient can lose any form of interpretability. Moreover, we show in Fig. 5 that a representation trained with an inverse prediction loss instead of a state-action forward prediction loss leads to poor downstream performance in the random maze environment.
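The reward structure described above bounds the return of a single episode. The following sketch computes that return from the three constants stated in the text; the function name `episode_return` and the use of an undiscounted sum are assumptions made for illustration, and the quantity plotted in Fig. 5 may be defined differently.

```python
STEP_REWARD = -0.1   # reward at every step on which the key is not found
KEY_REWARD = 1.0     # reward on the step on which the key is found
MAX_STEPS = 50       # episode ends after this many environment steps

def episode_return(steps_to_key=None):
    """Undiscounted return of one episode.

    steps_to_key: 1-based index of the step on which the key is reached,
    or None if the key is not reached within MAX_STEPS.
    """
    if steps_to_key is not None and steps_to_key <= MAX_STEPS:
        # All earlier steps cost STEP_REWARD, the final step yields KEY_REWARD.
        return STEP_REWARD * (steps_to_key - 1) + KEY_REWARD
    return STEP_REWARD * MAX_STEPS

print(episode_return(10))    # key reached on the tenth step
print(episode_return(None))  # key never reached
```

Under these assumptions the return lies between -5 (key never found) and 1 (key found on the first step).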

Planning

As seen in Fig. 3(a), after pre-training with the unsupervised losses, an interpretable disentangled representation with the corresponding agent transitions is obtained. Due to this disentanglement of the controllable and uncontrollable features, we can for instance employ prior knowledge that the uncontrollable features in the maze environment are static, and employ latent planning in the controllable latent space only (see Fig. 6). The planning algorithm used is derived from [34], and is used to successfully plan only in the controllable partition of the latent representation $z^c$, while freezing the input for $z^u$ regardless of planning depth. More details on the planning algorithm can be found in Appendix A-E. It can be observed that even when planning with a relatively small depth of 3, we achieve better performance than both the pre-trained representation with an $\epsilon$-greedy policy and the purely DDQN-updated encoder.

VI Limitations

While the work presented here provides a step towards a better understanding of disentangling controllable and uncontrollable features within an encoder architecture, there remain some limitations that we must acknowledge, and which can provide a basis for future research.

First, our method’s effectiveness was predominantly demonstrated on environments with relatively simple underlying dynamics. In these environments, the disentanglement process was easier to achieve due to the limited complexity of the internal dynamics present. As we begin to transfer our approach to more complex environments characterized by more extensive internal dynamics, two problems can arise: (i) the separation of controllable from uncontrollable features may not be as clear-cut in more complex MDPs and may instead lie on a spectrum, which blurs the fundamental difference between a state-only and a state-action forward predictor; (ii) interpretability will be harder to enforce when there is a large number of underlying factors of variation. As distinct seeds can give different orderings and signs of the neurons in the final layer of the encoder, identifying a factor of variation can become exponentially harder for more complex environments.

Lastly, while our work showed that an action-conditioned forward predictor can be preferred over an inverse predictor in some environments for isolating controllable features, this may not hold for all scenarios. The inherent properties of different environments might necessitate different predictors. Consequently, there could very well be MDPs where our current approach does not provide the same level of disentanglement as shown in the MDPs used in this paper.

Despite these limitations, we believe our work provides a strong foundation upon which future research can build and further extend the possibilities of achieving a highly interpretable latent representation through disentanglement of controllable and uncontrollable features.

VII Conclusion and Future Work

We have shown the possibility of disentangling controllable and uncontrollable features in an encoder architecture, strongly increasing the interpretability of the latent representation while also showing the potential use of this for downstream learning and planning, even in a single latent partition. This disentanglement of controllable and uncontrollable features in the latent representation of high-dimensional MDPs was achieved by propagating an action-conditioned forward prediction loss and a state-only forward prediction loss through distinct sections of the latent representation. Additionally, a contrastive loss and an adversarial loss were used to respectively avoid collapse and further disentangle the latent representation. Furthermore, we showed that an action-conditioned forward predictor can, in some environments, be preferred over an inverse predictor in terms of isolating controllable features in the representation. Finally, by employing forward prediction in latent space, we were able to successfully run a planning algorithm while leveraging the properties of the environment. In particular, the disentanglement of controllable and uncontrollable features allowed us to keep $z^u$ frozen regardless of planning depth in the context of a distribution of randomly generated mazes, i.e. we only do forward prediction in $z^c$.

Future work could focus on gradually transferring our notion of disentanglement and interpretability to environments with more extensive underlying internal dynamics. Further work could also look at the ordering of the latent dimensions, as a latent representation is often arbitrarily ordered. This means that distinct seeds will lead to a different ordering and sign of the neurons in the final layer of the encoder. For example, if seed one would give agent position +x and +y for neurons 1 and 2 respectively, then seed two could give agent position -y and +x to the same neurons. As we are additionally using a contrastive loss while learning our representation, these results are compliant with the theory that a contrastive loss can recover the original latent information up to an orthogonal linear transformation [35].
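The seed example above corresponds to a signed permutation of the two controllable neurons, which is a special case of an orthogonal linear transformation. The sketch below checks this for the $(+x, +y) \rightarrow (-y, +x)$ case; the matrix name, the helper functions and the two agent positions are hypothetical and chosen only for illustration.

```python
import math

# Seed one: neurons (1, 2) encode (+x, +y). Seed two: neurons (1, 2) encode (-y, +x).
SEED_TWO_FROM_SEED_ONE = [[0, -1],
                          [1, 0]]

def apply(matrix, point):
    """Multiply a 2x2 matrix with a 2-D point."""
    return tuple(sum(m * p for m, p in zip(row, point)) for row in matrix)

def is_orthogonal(matrix):
    """Check that the columns of a 2x2 matrix are orthonormal (M^T M = I)."""
    cols = list(zip(*matrix))
    return all(
        sum(a * b for a, b in zip(cols[i], cols[j])) == (1 if i == j else 0)
        for i in range(2) for j in range(2)
    )

agent_a, agent_b = (1.0, 2.0), (4.0, 6.0)   # two toy agent positions
dist_seed_one = math.dist(agent_a, agent_b)
dist_seed_two = math.dist(apply(SEED_TWO_FROM_SEED_ONE, agent_a),
                          apply(SEED_TWO_FROM_SEED_ONE, agent_b))
print(is_orthogonal(SEED_TWO_FROM_SEED_ONE), dist_seed_one, dist_seed_two)
```

Because the transformation is orthogonal, distances between encoded agent positions are identical under both seeds; only the assignment of factors to neurons differs.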

(a) Planning depth 3
(b) Planning depth 9
Figure 6: Visualization of the latent representation through an actual planning iteration utilizing a planning depth of 3 (a) and a planning depth of 9 (b), with the controllable representation $z^c \in \mathbb{R}^2$ (left), the uncontrollable representation $z^u \in \mathbb{R}^{6 \times 6}$ (middle), which is kept static throughout the planning depth, and the original pixel input $s_t \in \mathbb{R}^{48 \times 48}$ (right). The translucent red dots represent every possible encoded state in the random maze environment, the full blue dot represents the current encoded state, the red dots represent intermediate encoded states estimated by planning and the green dot represents the final predicted state as chosen by the planning algorithm, consistent with its depth.

Certain benefits can be obtained as well with a particular design of the encoder architecture, as we have done in this paper using estimates of the necessary dimensions of $z^c$ and $z^u$ for the different MDP environments. This can be seen as an inductive bias to aid disentanglement, as mentioned by [36]. Succeeding work could also focus on finding more algorithmic benefits of this disentanglement of controllable/uncontrollable features in more complex environments. For example, in the context of safety, a disentangled interpretable representation could allow incorporating latent state constraints in a planning algorithm. Lastly, as discussed by [13, 36], an interesting avenue could be to further investigate the trade-off between interpretability and downstream performance. This is because black-box representations such as the one in Figure 3(c) still seem to have excellent downstream performance with DDQN, whereas for the task of maze navigation, a human would perform substantially better using the representation portrayed in Figure 3(a) than using the representation in Figure 3(c).

References

  • [1] R. Bellman, “A markovian decision process,” in Journal of Mathematics and Mechanics, vol. 6, no. 5, 1957.
  • [2] R. Jonschkowski and O. Brock, “Learning state representations with robotic priors,” Autonomous Robots, vol. 39, no. 3, 2015.
  • [3] M. Jaderberg, V. Mnih, W. M. Czarnecki, T. Schaul, J. Z. Leibo, D. Silver, and K. Kavukcuoglu, “Reinforcement Learning with Unsupervised Auxiliary Tasks,” in International Conference on Learning Representations, ICLR, 2017.
  • [4] M. Laskin, A. Srinivas, and P. Abbeel, “CURL: Contrastive unsupervised representations for reinforcement learning,” in 37th International Conference on Machine Learning, ICML, 2020.
  • [5] K. H. Lee, I. Fischer, A. Z. Liu, Y. Guo, H. Lee, J. Canny, and S. Guadarrama, “Predictive information accelerates learning in RL,” in Advances in Neural Information Processing Systems, NIPS, 2020.
  • [6] D. Yarats, A. Zhang, I. Kostrikov, B. Amos, J. Pineau, and R. Fergus, “Improving Sample Efficiency in Model-Free Reinforcement Learning from Images,” in Proceedings of the AAAI Conference on Artificial Intelligence, AAAI, 2021.
  • [7] M. Schwarzer, A. Anand, R. Goel, R. D. Hjelm, A. Courville, and P. Bachman, “Data-Efficient Reinforcement Learning with Self-Predictive Representations,” in International Conference on Learning Representations, ICLR, 2021.
  • [8] I. Kostrikov, D. Yarats, and R. Fergus, “Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels,” in International Conference on Learning Representations, ICLR, 2021.
  • [9] V. Francois-Lavet, Y. Bengio, D. Precup, and J. Pineau, “Combined Reinforcement Learning via Abstract Representations,” in Proceedings of the AAAI Conference on Artificial Intelligence, AAAI, 2019.
  • [10] V. Thomas, J. Pondard, E. Bengio, M. Sarfati, P. Beaudoin, M.-J. Meurs, J. Pineau, D. Precup, and Y. Bengio, “Independently Controllable Factors,” arXiv preprint arXiv:1708.01289, 2017.
  • [11] T. Kipf, E. van der Pol, and M. Welling, “Contrastive Learning of Structured World Models,” in International Conference on Learning Representations, ICLR, 2020.
  • [12] K. Ahuja, J. Hartford, and Y. Bengio, “Weakly supervised representation learning with sparse perturbations,” in Advances in Neural Information Processing Systems, NIPS, 2022.
  • [13] C. Glanois, P. Weng, M. Zimmer, D. Li, T. Yang, J. Hao, and W. Liu, “A survey on interpretable reinforcement learning,” arXiv preprint arXiv:2112.13112, 2021.
  • [14] D. P. Kingma and M. Welling, “Auto-Encoding Variational Bayes,” in International Conference on Learning Representations, ICLR, 2014.
  • [15] I. Higgins, A. Pal, A. A. Rusu, L. Matthey, C. P. Burgess, A. Pritzel, M. Botvinick, C. Blundell, and A. Lerchner, “DARLA: Improving Zero-Shot Transfer in Reinforcement Learning,” in International Conference on Machine Learning, ICML, 2017.
  • [16] C. Gelada, S. Kumar, J. Buckman, O. Nachum, and M. G. Bellemare, “DeepMDP: Learning continuous latent space models for representation learning,” in International Conference on Machine Learning, ICML, 2019.
  • [17] D. Hafner, T. Lillicrap, I. Fischer, R. Villegas, D. Ha, H. Lee, and J. Davidson, “Learning latent dynamics for planning from pixels,” in International Conference on Machine Learning, ICML, 2019.
  • [18] D. Hafner, T. Lillicrap, M. Norouzi, and J. Ba, “Mastering Atari with Discrete World Models,” in International Conference on Learning Representations, ICLR, 2021.
  • [19] A. Laversanne-Finot, A. Pere, and P.-Y. Oudeyer, “Curiosity driven exploration of learned disentangled goal spaces,” in Proceedings of The 2nd Conference on Robot Learning, PMLR, 2018.
  • [20] D. Pathak, P. Agrawal, A. A. Efros, and T. Darrell, “Curiosity-driven Exploration by Self-supervised Prediction,” in IEEE Conference on Computer Vision and Pattern Recognition Workshops, CVPRW, 2017.
  • [21] A. P. Badia, P. Sprechmann, A. Vitvitskyi, D. Guo, B. Piot, S. Kapturowski, O. Tieleman, M. Arjovsky, A. Pritzel, A. Bolt, and C. Blundell, “Never Give Up: Learning Directed Exploration Strategies,” in International Conference on Learning Representations, ICLR, 2020.
  • [22] Y. Efroni, D. Misra, A. Krishnamurthy, A. Agarwal, and J. Langford, “Provable RL with exogenous distractors via multistep inverse dynamics,” in International Conference on Machine Learning, ICML, 2021.
  • [23] A. Lamb, R. Islam, Y. Efroni, A. Didolkar, D. Misra, D. Foster, L. Molu, R. Chari, A. Krishnamurthy, and J. Langford, “Guaranteed discovery of controllable latent states with multi-step inverse models,” arXiv preprint arXiv:2207.08229, 2022.
  • [24] D. Bertoin and E. Rachelson, “Disentanglement by cyclic reconstruction,” in IEEE Transactions on Neural Networks and Learning Systems, 2022.
  • [25] X. Fu, G. Yang, P. Agrawal, and T. Jaakkola, “Learning task informed abstractions,” in International Conference on Machine Learning, ICML, 2021.
  • [26] T. Wang, S. Du, A. Torralba, P. Isola, A. Zhang, and Y. Tian, “Denoised MDPs: Learning world models better than the world itself,” in Proceedings of the 39th International Conference on Machine Learning, PMLR, 2022.
  • [27] E. van der Pol, D. Worrall, H. van Hoof, F. Oliehoek, and M. Welling, “Mdp homomorphic networks: Group symmetries in reinforcement learning,” in Advances in Neural Information Processing Systems, NIPS, 2020.
  • [28] Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, V. Lempitsky, U. Dogan, M. Kloft, F. Orabona, T. Tommasi, and a. Ganin, “Domain-Adversarial Training of Neural Networks,” in Journal of Machine Learning Research, vol. 17, 2016.
  • [29] I. J. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio, “Generative Adversarial Networks,” in Advances in Neural Information Processing Systems, NIPS, 2014.
  • [30] V. Mnih, K. Kavukcuoglu, D. Silver, A. A. Rusu, J. Veness, M. G. Bellemare, A. Graves, M. Riedmiller, A. K. Fidjeland, G. Ostrovski, S. Petersen, C. Beattie, A. Sadik, I. Antonoglou, H. King, D. Kumaran, D. Wierstra, S. Legg, and D. Hassabis, “Human-level control through deep reinforcement learning,” Nature, vol. 518, 2015.
  • [31] H. van Hasselt, A. Guez, and D. Silver, “Deep Reinforcement Learning with Double Q-learning,” in Proceedings of the AAAI Conference on Artificial Intelligence, AAAI, 2016.
  • [32] I. Higgins, D. Amos, D. Pfau, S. Racanière, L. Matthey, D. J. Rezende, and A. Lerchner, “Towards a definition of disentangled representations,” arXiv preprint arXiv:1812.02230, 2018.
  • [33] Y. Tassa, Y. Doron, A. Muldal, T. Erez, Y. Li, D. d. L. Casas, D. Budden, A. Abdolmaleki, J. Merel, A. Lefrancq, T. Lillicrap, and M. Riedmiller, “DeepMind Control Suite,” arXiv preprint arXiv:1801.00690, 2018.
  • [34] J. Oh, S. Singh, and H. Lee, “Value Prediction Network,” in Advances in Neural Information Processing Systems, NIPS, 2017.
  • [35] R. S. Zimmermann, Y. Sharma, S. Schneider, M. Bethge, and W. Brendel, “Contrastive learning inverts the data generating process,” in Proceedings of the 38th International Conference on Machine Learning, ser. Proceedings of Machine Learning Research, M. Meila and T. Zhang, Eds., vol. 139, PMLR, 18–24 Jul 2021, pp. 12979–12990.
  • [36] F. Locatello, S. Bauer, M. Lucic, G. Raetsch, S. Gelly, B. Schölkopf, and O. Bachem, “Challenging common assumptions in the unsupervised learning of disentangled representations,” in Proceedings of the 36th International Conference on Machine Learning, PMLR, 2019.
  • [37] D. P. Kingma and J. Ba, “Adam: A Method for Stochastic Optimization,” in International Conference on Learning Representations, ICLR, 2015.

Appendix A Additional Material

A-A Ablation of the contrastive scalar

Without using a pixel reconstruction loss, the contrastive loss $\mathcal{L}_H$ is crucial in avoiding the trivial solution for any latent forward predictor [9, 16]. The contrastive scalar $C_d$ that regulates $\mathcal{L}_H$, however, remains the most influential hyperparameter. When $C_d$ is chosen too high, the representation remains in a compact cluster. On the other hand, when $C_d$ is chosen too low, the representation is spread out unnecessarily in order to enforce large distances between individual latent samples. Two ablations of the contrastive scalar $C_d$ are shown in Fig. 7.
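To illustrate the role of the contrastive scalar, the sketch below assumes that the contrastive term penalises a pair of latent states at Euclidean distance $\delta$ with $\exp(-C_d\,\delta)$, the form used in [9]. This expression and the function name `contrastive_penalty` are assumptions for illustration and are not stated in this appendix; the two values of $C_d$ are those of Fig. 7.

```python
import math

def contrastive_penalty(distance, c_d):
    """Assumed penalty exp(-C_d * distance) for a pair of latent states."""
    return math.exp(-c_d * distance)

distances = (0.1, 0.5, 1.0)
for c_d in (3, 13):
    row = ", ".join(f"{contrastive_penalty(d, c_d):.4f}" for d in distances)
    print(f"C_d={c_d}: penalties at distances {distances} -> {row}")
```

Under this assumed form, the penalty for the larger scalar is already below 0.01 at a distance of 0.5, so little pressure remains to push samples apart, which is in line with the compact cluster described above. With the smaller scalar the penalty decays slowly and keeps rewarding ever larger distances.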

(a) $C_d = 13$
(b) $C_d = 3$
Figure 7: Ablation of the hyperparameter $C_d$, where a higher value of $C_d$ enforces less entropy in the representation, while a lower value of $C_d$ especially pushes the controllable features $z^c$ towards shapes that ensure large distances between samples. In both figures, the left column is $z^c \in \mathbb{R}^2$, the middle column is $z^u \in \mathbb{R}^{6 \times 6}$ and the right column is the input state $\in \mathbb{R}^{48 \times 48}$.

A-B Ablation of learning rates

We show experiments in Fig. 8 and Fig. 9 where we employ different learning rates for the encoder and the action-conditioned forward predictor, respectively.

A-C Ablation of the detachment of $z^u$ and ablation of the residual prediction

As seen in the main paper in Figure 2, we detach the uncontrollable representation $z^u$ from $\mathcal{L}_c$, as we do not want controllable features to be present in $z^u$. We can see in Figure 10 that updating $z^u$ with $\mathcal{L}_c$ leads to slightly better transition predictions in $z^c$, but also results in a less interpretable encoding of $z^u$. Furthermore, we can also see in Figure 10 that, when using normal forward predictions instead of residual forward predictions, we lose almost all of our interpretable structure in $z^u$.
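The difference between the two prediction modes can be sketched as follows, assuming that a residual forward prediction means the network outputs only the change that is added to the current latent, whereas a normal forward prediction outputs the next latent directly. The function names, the zero-output network and the toy wall vector are hypothetical and serve only as an illustration.

```python
def predict_direct(z, network):
    """'Normal' forward prediction: the network outputs the next latent itself."""
    return list(network(z))

def predict_residual(z, network):
    """Residual forward prediction: the network outputs only the change in z."""
    return [z_i + delta_i for z_i, delta_i in zip(z, network(z))]

def zero_network(z):
    """Hypothetical predictor that outputs zeros for every input."""
    return [0.0 for _ in z]

z_u = [1.0, 0.0, 1.0, 1.0]   # toy encoding of a static wall layout
print(predict_residual(z_u, zero_network))  # reproduces z_u
print(predict_direct(z_u, zero_network))    # discards z_u
```

For features that do not change over time, such as the walls of a maze, the residual form already reproduces its input when the network outputs zero, while the direct form has to learn the identity mapping.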

(a) Encoder learning rate of 2e-4
(b) Encoder learning rate of 2e-6
Figure 8: Ablation of the learning rates for the encoder, where too low a learning rate causes collapse of $z^c$ and too high a learning rate causes distortions in the uncontrollable features $z^u$. In both figures, the left column is $z^c \in \mathbb{R}^2$, the middle column is $z^u \in \mathbb{R}^{6 \times 6}$ and the right column is the input state $\in \mathbb{R}^{48 \times 48}$.
(a) $T_c$ learning rate of 4e-4
(b) $T_c$ learning rate of 4e-6
Figure 9: Ablation of the learning rates for the action-conditioned forward predictor. Too high a learning rate will cause the controllable representation to lose structure, while a low learning rate retains structure but does not learn strong transition dynamics. In both figures, the left column is $z^c \in \mathbb{R}^2$, the middle column is $z^u \in \mathbb{R}^{6 \times 6}$ and the right column is the input state $\in \mathbb{R}^{48 \times 48}$.
(a) No detachment of $z^u$
(b) No residual predictions of $z^c_{t+1}$ and $z^u_{t+1}$
Figure 10: In both figures, the left column is $z^c \in \mathbb{R}^2$, the middle column is $z^u \in \mathbb{R}^{6 \times 6}$ and the right column is the input state $\in \mathbb{R}^{48 \times 48}$.

A-D Ablation of the entropy loss $\mathcal{L}_{H2}$

As the number of possible encoded maze architectures goes to infinity due to the procedural generation, a collapse in the controllable features $z^c$ can be noticed when using only $\mathcal{L}_{H1}$ as the contrastive loss (see Fig. 11). On the other hand, when using only $\mathcal{L}_{H2}$ as the contrastive loss, there is no longer a clear distinction in the uncontrollable representation $z^u$. The best results were obtained using a combination of the aforementioned losses.

(a) $\mathcal{L}_H = \mathcal{L}_{H1}$
(b) $\mathcal{L}_H = \mathcal{L}_{H2}$
Figure 11: In both figures, the left column is $z^c \in \mathbb{R}^2$, the middle column is $z^u \in \mathbb{R}^{6 \times 6}$ and the right column is the input state $\in \mathbb{R}^{48 \times 48}$.

A-E Planning

We use a planning algorithm derived from [34, 9], where we employ d-step planning as:

$$\hat{Q}^{d}((\hat{z}^c_t, z^u), a) = \begin{cases} P((\hat{z}^c_t, z^u), a; \theta_r) + \Gamma((\hat{z}^c_t, z^u), a; \theta_\gamma) \, \underset{a' \in \mathcal{A}^*}{\max} \, \hat{Q}^{d-1}((\hat{z}^c_{t+1}, z^u), a'), & \text{if } d > 0 \\ Q((\hat{z}^c_t, z^u), a; \theta), & \text{if } d = 0 \end{cases} \tag{9}$$

$$Q_{plan}^{D}((\hat{z}^c_t, z^u), a) = \sum_{d=0}^{D} \hat{Q}^{d}((\hat{z}^c_t, z^u), a) \tag{10}$$

where $P(s_t, a; \theta_r): \mathcal{Z} \times \mathcal{A} \rightarrow \mathcal{R}$ represents the reward predictor and $\Gamma(s, a; \theta_\gamma): \mathcal{Z} \times \mathcal{A} \rightarrow \gamma$ represents the discount value predictor. The action is chosen by taking the argmax of $Q_{plan}^{D}((\hat{z}^c_t, z^u), a)$. Note that in the results from Section 5.3, we only forward predict in the controllable latent space $z^c$, and $z^u$ remains fixed regardless of planning depth. This is possible by making use of the prior knowledge of the maze environments together with a disentangled controllable and uncontrollable latent representation.
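The recursion of Eqs. (9) and (10) can be sketched in a few lines. The example below is a toy illustration and not the implementation used in the experiments: the learned networks are replaced by hand-written stand-ins on a hypothetical one-dimensional corridor, the controllable latent is an integer position, the uncontrollable latent is a frozen set of blocked cells, and $\mathcal{A}^*$ is assumed to be the full action set.

```python
ACTIONS = (-1, +1)      # hypothetical 1-D actions: move left / move right
GOAL, GAMMA = 3, 0.9    # hypothetical goal cell and discount value

def forward_c(z_c, z_u, a):
    """Stand-in for the forward predictor in z^c; z_u (blocked cells) is only read."""
    nxt = z_c + a
    return z_c if nxt in z_u else nxt

def reward(z_c, z_u, a):
    """Stand-in for the reward predictor P."""
    return 1.0 if forward_c(z_c, z_u, a) == GOAL else -0.1

def discount(z_c, z_u, a):
    """Stand-in for the discount predictor Gamma (0 when the goal is reached)."""
    return 0.0 if forward_c(z_c, z_u, a) == GOAL else GAMMA

def q_model_free(z_c, z_u, a):
    """Stand-in for the learned Q-network; deliberately uninformative."""
    return 0.0

def q_hat(z_c, z_u, a, d):
    """Eq. (9): depth-d estimate. z_u is passed through unchanged at every depth."""
    if d == 0:
        return q_model_free(z_c, z_u, a)
    z_c_next = forward_c(z_c, z_u, a)
    best_next = max(q_hat(z_c_next, z_u, a2, d - 1) for a2 in ACTIONS)
    return reward(z_c, z_u, a) + discount(z_c, z_u, a) * best_next

def q_plan(z_c, z_u, a, depth):
    """Eq. (10): sum of the estimates for depths 0 to D."""
    return sum(q_hat(z_c, z_u, a, d) for d in range(depth + 1))

def plan_action(z_c, z_u, depth):
    """Choose the action with the highest planned value."""
    return max(ACTIONS, key=lambda a: q_plan(z_c, z_u, a, depth))

walls = frozenset({-1})            # frozen uncontrollable latent
print(plan_action(0, walls, 3))    # 1
```

In this toy setting the goal lies three steps to the right, so a depth of 3 is the smallest depth at which the planned values distinguish the two actions. Only `forward_c` is rolled forward during planning, mirroring the fact that $z^u$ stays fixed.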

A-F Inverse Prediction

A common single-step inverse prediction is defined as:

$$\hat{a}_{t}=f(s_{t},s_{t+1})$$ (11)

where $\hat{a}_{t}$ is the predicted action and $f(s_{t},s_{t+1})$ represents an arbitrarily structured function. In the random maze environment, we use a parameterized inverse predictor which predicts in latent space:

$$\hat{a}_{t}=I(z^{c}_{t},z^{c}_{t+1},z^{u}_{t},z^{u}_{t+1};\theta_{inv})$$ (12)

where $I(\cdot;\theta_{inv})\in\mathcal{I}:\mathcal{Z}\rightarrow\mathcal{A}$ is a parameterized inverse prediction function. Since we have 4 actions, we use the 4-dimensional logit output $\hat{a}_{t}$ to calculate the inverse prediction loss $\mathcal{L}_{inv}$ as:

$$S(\hat{a}_{i})=\frac{\exp(\hat{a}_{i})}{\sum_{j=1}^{n_{a}}\exp(\hat{a}_{j})},\quad\mathcal{L}_{inv}=-\sum_{i=1}^{n_{a}}a_{i}\log(S(\hat{a}_{i}))$$ (13)

Here, $n_{a}$ is the number of actions, $S(\hat{a}_{i})$ represents the softmax operator and $a_{i}$ is the actual action, given as a 0 or 1 truth label. This is more commonly known as the cross-entropy loss.
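As a worked illustration of Equation 13, the following standard-library sketch computes the softmax of a 4-dimensional logit vector and the cross-entropy against a one-hot action label. The function names and example logits are assumptions made for illustration; subtracting the maximum logit is a standard numerical-stability step that leaves the softmax unchanged.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def inverse_loss(logits, action_index):
    """Cross-entropy between softmax(logits) and a one-hot action label."""
    probs = softmax(logits)
    one_hot = [1.0 if i == action_index else 0.0 for i in range(len(logits))]
    # Only the term of the true action contributes, since a_i is 0 elsewhere.
    return -sum(a * math.log(p) for a, p in zip(one_hot, probs) if a > 0.0)


# Uninformative logits over 4 actions give a loss of log(4).
uniform_loss = inverse_loss([0.0, 0.0, 0.0, 0.0], action_index=2)
# A large logit on the true action drives the loss toward zero.
confident_loss = inverse_loss([0.0, 0.0, 5.0, 0.0], action_index=2)
```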

A-G Reconstruction

We run an additional ablation on the four-maze environment, where the contrastive loss $\mathcal{L}_{H}$ is replaced with a pixel reconstruction loss. The resulting representation comparison can be seen in Fig. 12.

(a) Contrastive loss
(b) Reconstruction loss
Figure 12: Visualization in a maze environment of the disentanglement of the controllable latent $z^{c}\in\mathbb{R}^{2}$ on the horizontal axes, and the uncontrollable latent $z^{u}\in\mathbb{R}^{1}$ on the vertical axis, given for all states in the four maze environments shown in four different colors. The representation is trained on high-dimensional tuples $(s_{t},a_{t},r_{t},s_{t+1})$, sampled from a replay buffer $\mathcal{B}$, gathered from random trajectories in the four maze environments. All possible states are encoded with $z_{t}=f(s_{t};\theta_{enc})$ and plotted in (a) and (b) together with the transition prediction for each possible action. In (a), a clear disentanglement between the controllable agent's position and the uncontrollable wall architecture is portrayed. In (b), it seems that a reconstruction loss groups observations with similar pixel inputs together, and thus allows the forward predictors to 'collapse' to unit matrices, decreasing representation quality.

A-H T-SNE

We conduct an additional experiment in the random maze environment where we use a latent dimension of 32, partition it in half to form $z^{c}\in\mathbb{R}^{16}$ and $z^{u}\in\mathbb{R}^{16}$, and show a T-SNE visualization of 6 different trajectories in random mazes in Fig. 13. Note that, because the trajectories are random, only a subset of the possible agent positions in every random maze is present.

(a) T-SNE of $z^{u}$ and $z^{c}$
(b) 6 random mazes
Figure 13: Ablation of a dimensionality increase in our random maze environment. Here, the total latent space is a 32-dimensional MLP output, where $z^{c}$ and $z^{u}$ are both 16-dimensional. In (a), 6 random trajectories are plotted using T-SNE (perplexity=20) in different colors for both $z^{c}$ and $z^{u}$, where $z^{u}$ remains similar across a trajectory (same wall architecture), and $z^{c}$ differs across the trajectory (different agent positions). In (b), the random mazes from which the trajectories were taken are shown.

Appendix B Experiment details

The PyTorch framework was used for all experiments, as well as the Adam optimizer [37]. We employ a batch size of 32 tuples $(s_{t},a_{t},r_{t},s_{t+1})$ for every update. In all experiments, we detach $z^{c}_{t}$ in the calculation of $\mathcal{L}_{c}$, as it allowed us to use a larger learning rate for $T_{c}$ without causing instabilities.

Simple Maze

The replay buffer $\mathcal{B}$ is filled with 5k transitions from each of the four wall architectures. The transitions are collected by the agent following a random policy. The learning rate for the encoder is $5\cdot 10^{-5}$, for the action-conditioned forward predictor $1\cdot 10^{-3}$ and for the uncontrollable forward predictor $5\cdot 10^{-5}$. The contrastive scalar $C_{d}$ is set to 15.

Catcher

The replay buffer $\mathcal{B}$ is filled with 25k transitions. The transitions are collected by the agent following a random policy. A new random maze is created after 50 time steps or when the reward is acquired. The learning rate for the encoder is $2\cdot 10^{-5}$, for the action-conditioned forward predictor $4\cdot 10^{-5}$ and for the uncontrollable forward predictor $1\cdot 10^{-5}$. When using the adversarial loss, we use a learning rate of $1\cdot 10^{-3}$ for the adversarial predictor. The contrastive scalar $C_{d}$ is set to 5.

Random Maze

The replay buffer $\mathcal{B}$ is filled with 50k transitions, representing around 1000 maze architectures. The transitions are collected by the agent following a random policy. The learning rates used are equal to those of the catcher environment; for the encoder $2\cdot 10^{-5}$, for the action-conditioned forward predictor $4\cdot 10^{-5}$ and for the uncontrollable forward predictor $1\cdot 10^{-5}$. After freezing the encoder, we train the action-conditioned forward predictor for an additional 250k iterations on the same 50k transitions in the buffer $\mathcal{B}$. For updating the Q-network with DDQN, we use a learning rate of $1\cdot 10^{-4}$, and a $\tau$ of 0.02. The contrastive scalar $C_{d}$ is set to 13. When using planning, we employ a learning rate of $5\cdot 10^{-5}$ for the reward and discount prediction networks.
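For convenience, the per-environment settings listed in this appendix can be gathered into a single configuration. The sketch below only restates the reported values; the dictionary layout and key names are hypothetical and do not come from the authors' code.

```python
# Shared setting reported for all experiments.
BATCH_SIZE = 32

# Key names are illustrative; values are the ones reported in this appendix.
HYPERPARAMS = {
    "simple_maze": {
        "buffer_transitions": 4 * 5_000,  # 5k per wall architecture, 4 architectures
        "lr_encoder": 5e-5,
        "lr_forward_controllable": 1e-3,
        "lr_forward_uncontrollable": 5e-5,
        "contrastive_scalar": 15,
    },
    "catcher": {
        "buffer_transitions": 25_000,
        "lr_encoder": 2e-5,
        "lr_forward_controllable": 4e-5,
        "lr_forward_uncontrollable": 1e-5,
        "lr_adversarial": 1e-3,  # only when the adversarial loss is used
        "contrastive_scalar": 5,
    },
    "random_maze": {
        "buffer_transitions": 50_000,
        "lr_encoder": 2e-5,
        "lr_forward_controllable": 4e-5,
        "lr_forward_uncontrollable": 1e-5,
        "extra_forward_iterations": 250_000,  # after freezing the encoder
        "lr_q_network": 1e-4,  # DDQN updates
        "tau": 0.02,
        "lr_reward_discount": 5e-5,  # only when planning is used
        "contrastive_scalar": 13,
    },
}

# The three learning rates shared by the catcher and random maze settings.
shared_keys = ["lr_encoder", "lr_forward_controllable", "lr_forward_uncontrollable"]
```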

Contrastive Loss

For the catcher and random maze environment, given that $z^{c}$ is 1 or 2-dimensional, and $z^{u}$ is a 36-dimensional feature map, we alleviate dimensional mismatch when calculating the contrastive loss in Equation 4 in the main paper. This is done by taking a random subset of 15 out of 36 feature values in $z^{u}$ for every batch.
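The subsampling step can be sketched as below. This is an illustration under one stated assumption: a single set of 15 indices is drawn per batch and applied to every sample in that batch, which is one reading of "for every batch". The function name and the toy batch are hypothetical.

```python
import random


def subsample_uncontrollable(z_u_batch, k=15, rng=random):
    """Keep the same k randomly chosen feature indices for a whole batch.

    z_u_batch: list of flattened z^u vectors (here 36 values each).
    Returns the reduced batch and the sorted indices that were kept.
    """
    n_features = len(z_u_batch[0])
    # Draw k distinct indices without replacement.
    idx = sorted(rng.sample(range(n_features), k))
    return [[z[i] for i in idx] for z in z_u_batch], idx


# Toy batch of 32 samples with 36 uncontrollable features each.
rng = random.Random(0)
batch = [[float(i + 36 * b) for i in range(36)] for b in range(32)]
subset, idx = subsample_uncontrollable(batch, k=15, rng=rng)
```

Calling the function again for the next batch draws a fresh set of indices.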

Appendix C Network Architecture

We use the same base encoder for all experiments, made up of 2 convolutional layers of 32 channels each, with a kernel size of 3 and stride 2, except for the final layer which has stride 1. Both convolutional layers have a Rectified Linear Unit (ReLU) nonlinear activation.
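The spatial size produced by the base encoder follows from the usual convolution output-size arithmetic. The sketch below assumes zero padding and uses a hypothetical 48x48 input, since neither is specified in this appendix.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output side length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def base_encoder_side(size):
    # First layer: kernel 3, stride 2.
    size = conv_out(size, kernel=3, stride=2)
    # Final layer: kernel 3, stride 1.
    size = conv_out(size, kernel=3, stride=1)
    return size


side = base_encoder_side(48)   # hypothetical input resolution
flat_dim = 32 * side * side    # 32 channels, flattened
```

With these assumptions, `flat_dim` is the input width of the linear layer that follows the flattened base encoder.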

In the quadruple maze environment, the output of the base convolutional encoder is flattened and used as an input to a single linear layer with 3 outputs ($z^{c}+z^{u}$) and a hyperbolic tangent (tanh) activation function.

In the catcher and random maze environments, we use the following encoder head to extract the uncontrollable features; the base convolutional layers are followed by a single convolutional layer with 32 channels, a kernel size of 4 and a stride of 1. This layer is followed by a ReLU activation function and an AveragePool layer with an output size of 6. For the controllable features, we flatten the output of the base convolutional encoder and use this as an input to a linear layer with 200 neurons and a tanh activation function. This layer is followed by another linear layer with $n_{c}$ neurons and a tanh activation function.

The transition and prediction models all have the same structure, with linear layers of 32-128-128-32-x neurons where x is the output dimension in line with the predicted feature’s dimension. The linear layers all have tanh activation functions except for the final output. Only the action-conditioned transition predictor of the random maze environment has larger layer sizes, with linear layers of 128-512-512-128-2, to account for slightly more complicated transitions. The DQN network used is of size 128-512-512-128-4, with an output value corresponding to each possible action.

C-A Quadruple Maze Progression

(a) 1k iterations
(b) 2k iterations
(c) 5k iterations
Figure 14: Progression of the separation of the controllable $z^{c}$ (x and y-axis) and uncontrollable $z^{u}$ (z-axis) features in the maze environment.

C-B Catcher

(a) $z^{c}\in\mathbb{R}^{1}$ without $\mathcal{L}_{adv}$
(b) $z^{c}\in\mathbb{R}^{2}$ without $\mathcal{L}_{adv}$
Figure 15: Comparison of training the representation for the catcher environment with either 1 or 2 dimensions for the controllable representation $z^{c}$. When using more dimensions for $z^{c}$ than needed, it can be observed that some information about the ball position can be present in $z^{c}$.