-
Learning Cognitive Maps from Transformer Representations for Efficient Planning in Partially Observed Environments
Authors:
Antoine Dedieu,
Wolfgang Lehrach,
Guangyao Zhou,
Dileep George,
Miguel Lázaro-Gredilla
Abstract:
Despite their stellar performance on a wide range of tasks, including in-context tasks only revealed during inference, vanilla transformers and variants trained for next-token predictions (a) do not learn an explicit world model of their environment which can be flexibly queried and (b) cannot be used for planning or navigation. In this paper, we consider partially observed environments (POEs), where an agent receives perceptually aliased observations as it navigates, which makes path planning hard. We introduce a transformer with (multiple) discrete bottleneck(s), TDB, whose latent codes learn a compressed representation of the history of observations and actions. After training a TDB to predict the future observation(s) given the history, we extract interpretable cognitive maps of the environment from its active bottleneck(s) indices. These maps are then paired with an external solver to solve (constrained) path planning problems. First, we show that a TDB trained on POEs (a) retains the near-perfect predictive performance of a vanilla transformer or an LSTM while (b) solving shortest path problems exponentially faster. Second, a TDB extracts interpretable representations from text datasets, while reaching higher in-context accuracy than vanilla sequence models. Finally, in new POEs, a TDB (a) reaches near-perfect in-context accuracy, (b) learns accurate in-context cognitive maps, and (c) solves in-context path planning problems.
Submitted 11 January, 2024;
originally announced January 2024.
-
Schema-learning and rebinding as mechanisms of in-context learning and emergence
Authors:
Sivaramakrishnan Swaminathan,
Antoine Dedieu,
Rajkumar Vasudeva Raju,
Murray Shanahan,
Miguel Lazaro-Gredilla,
Dileep George
Abstract:
In-context learning (ICL) is one of the most powerful and most unexpected capabilities to emerge in recent transformer-based large language models (LLMs). Yet the mechanisms that underlie it are poorly understood. In this paper, we demonstrate that comparable ICL capabilities can be acquired by an alternative sequence prediction learning method using clone-structured causal graphs (CSCGs). Moreover, a key property of CSCGs is that, unlike transformer-based LLMs, they are interpretable, which considerably simplifies the task of explaining how ICL works. Specifically, we show that it uses a combination of (a) learning template (schema) circuits for pattern completion, (b) retrieving relevant templates in a context-sensitive manner, and (c) rebinding of novel tokens to appropriate slots in the templates. We go on to marshal evidence for the hypothesis that similar mechanisms underlie ICL in LLMs. For example, we find that, with CSCGs as with LLMs, different capabilities emerge at different levels of overparameterization, suggesting that overparameterization helps in learning more complex template (schema) circuits. By showing how ICL can be achieved with small models and datasets, we open up a path to novel architectures, and take a vital step towards a more general understanding of the mechanics behind this important capability.
Submitted 15 June, 2023;
originally announced July 2023.
-
Learning noisy-OR Bayesian Networks with Max-Product Belief Propagation
Authors:
Antoine Dedieu,
Guangyao Zhou,
Dileep George,
Miguel Lazaro-Gredilla
Abstract:
Noisy-OR Bayesian Networks (BNs) are a family of probabilistic graphical models which express rich statistical dependencies in binary data. Variational inference (VI) has been the main method proposed to learn noisy-OR BNs with complex latent structures (Jaakkola & Jordan, 1999; Ji et al., 2020; Buhai et al., 2020). However, the proposed VI approaches either (a) use a recognition network with standard amortized inference that cannot induce "explaining-away"; or (b) assume a simple mean-field (MF) posterior which is vulnerable to bad local optima. Existing MF VI methods also update the MF parameters sequentially, which makes them inherently slow. In this paper, we propose parallel max-product as an alternative algorithm for learning noisy-OR BNs with complex latent structures and we derive a fast stochastic training scheme that scales to large datasets. We evaluate both approaches on several benchmarks where VI is the state-of-the-art and show that our method (a) achieves better test performance than Ji et al. (2020) for learning noisy-OR BNs with hierarchical latent structures on large sparse real datasets; (b) recovers a higher number of ground truth parameters than Buhai et al. (2020) from cluttered synthetic scenes; and (c) solves the 2D blind deconvolution problem from Lazaro-Gredilla et al. (2021) and variants (including binary matrix factorization), while VI catastrophically fails and is up to two orders of magnitude slower.
Submitted 31 January, 2023;
originally announced February 2023.
-
PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX
Authors:
Guangyao Zhou,
Antoine Dedieu,
Nishanth Kumar,
Wolfgang Lehrach,
Miguel Lázaro-Gredilla,
Shrinu Kushagra,
Dileep George
Abstract:
PGMax is an open-source Python package for (a) easily specifying discrete Probabilistic Graphical Models (PGMs) as factor graphs; and (b) automatically running efficient and scalable loopy belief propagation (LBP) in JAX. PGMax supports general factor graphs with tractable factors, and leverages modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with up to three orders-of-magnitude inference time speedups. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up new research possibilities. Our source code, examples and documentation are available at https://github.com/deepmind/PGMax.
Submitted 24 March, 2023; v1 submitted 8 February, 2022;
originally announced February 2022.
-
Graphical Models with Attention for Context-Specific Independence and an Application to Perceptual Grouping
Authors:
Guangyao Zhou,
Wolfgang Lehrach,
Antoine Dedieu,
Miguel Lázaro-Gredilla,
Dileep George
Abstract:
Discrete undirected graphical models, also known as Markov Random Fields (MRFs), can flexibly encode probabilistic interactions of multiple variables, and have enjoyed successful applications to a wide range of problems. However, a well-known yet little studied limitation of discrete MRFs is that they cannot capture context-specific independence (CSI). Existing methods require carefully developed theories and purpose-built inference methods, which limit their applications to only small-scale problems. In this paper, we propose the Markov Attention Model (MAM), a family of discrete MRFs that incorporates an attention mechanism. The attention mechanism allows variables to dynamically attend to some other variables while ignoring the rest, and enables capturing of CSIs in MRFs. A MAM is formulated as an MRF, allowing it to benefit from the rich set of existing MRF inference methods and scale to large models and datasets. To demonstrate MAM's capabilities to capture CSIs at scale, we apply MAMs to capture an important type of CSI that is present in a symbolic approach to recurrent computations in perceptual grouping. Experiments on two recently proposed synthetic perceptual grouping tasks and on realistic images demonstrate the advantages of MAMs in sample-efficiency, interpretability and generalizability when compared with strong recurrent neural network baselines, and validate MAM's capabilities to efficiently capture CSIs at scale.
Submitted 6 December, 2021;
originally announced December 2021.
-
Perturb-and-max-product: Sampling and learning in discrete energy-based models
Authors:
Miguel Lazaro-Gredilla,
Antoine Dedieu,
Dileep George
Abstract:
Perturb-and-MAP offers an elegant approach to approximately sample from an energy-based model (EBM) by computing the maximum-a-posteriori (MAP) configuration of a perturbed version of the model. Sampling in turn enables learning. However, this line of research has been hindered by the general intractability of the MAP computation. Very few works venture outside tractable models, and when they do, they use linear programming approaches, which, as we will show, have several limitations. In this work we present perturb-and-max-product (PMP), a parallel and scalable mechanism for sampling and learning in discrete EBMs. Models can be arbitrary as long as they are built using tractable factors. We show that (a) for Ising models, PMP is orders of magnitude faster than Gibbs and Gibbs-with-Gradients (GWG) at learning and generating samples of similar or better quality; (b) PMP is able to learn and sample from RBMs; (c) in a large, entangled graphical model in which Gibbs and GWG fail to mix, PMP succeeds.
Submitted 5 November, 2021; v1 submitted 3 November, 2021;
originally announced November 2021.
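The perturb-and-MAP idea summarized in this abstract can be illustrated on a model small enough to enumerate. The sketch below is not the paper's PMP algorithm and does not use max-product at all: it assumes a hypothetical two-spin Ising model and perturbs every joint configuration with independent Gumbel noise, in which case the perturbed MAP configuration is an exact sample (the Gumbel-max trick). All function names and parameter values are illustrative assumptions.

```python
import itertools
import math
import random


def gumbel(rng):
    """Draw one standard Gumbel variate via inverse-CDF sampling."""
    u = rng.random()
    while u == 0.0:  # avoid log(0); random() can return exactly 0.0
        u = rng.random()
    return -math.log(-math.log(u))


def log_potential(x, coupling=0.8, fields=(0.3, -0.5)):
    """Unnormalized log-probability of a two-spin configuration (hypothetical values)."""
    return coupling * x[0] * x[1] + fields[0] * x[0] + fields[1] * x[1]


# All four joint configurations of two spins in {-1, +1}.
CONFIGS = list(itertools.product((-1, 1), repeat=2))


def perturb_and_map_sample(rng):
    """Add independent Gumbel noise to each configuration's score, return the argmax."""
    return max(CONFIGS, key=lambda x: log_potential(x) + gumbel(rng))


def exact_probabilities():
    """Normalize the potentials by brute-force enumeration."""
    weights = [math.exp(log_potential(x)) for x in CONFIGS]
    total = sum(weights)
    return {x: w / total for x, w in zip(CONFIGS, weights)}


def empirical_probabilities(num_samples=20000, seed=0):
    """Estimate configuration frequencies from repeated perturb-and-MAP draws."""
    rng = random.Random(seed)
    counts = {x: 0 for x in CONFIGS}
    for _ in range(num_samples):
        counts[perturb_and_map_sample(rng)] += 1
    return {x: c / num_samples for x, c in counts.items()}
```

Comparing `exact_probabilities()` with `empirical_probabilities()` shows the two distributions agreeing up to sampling noise. Full-configuration perturbation scales exponentially with the number of variables, which is why practical methods resort to cheaper perturbations and approximate MAP solvers.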
-
Sample-Efficient L0-L2 Constrained Structure Learning of Sparse Ising Models
Authors:
Antoine Dedieu,
Miguel Lázaro-Gredilla,
Dileep George
Abstract:
We consider the problem of learning the underlying graph of a sparse Ising model with $p$ nodes from $n$ i.i.d. samples. The most recent and best performing approaches combine an empirical loss (the logistic regression loss or the interaction screening loss) with a regularizer (an L1 penalty or an L1 constraint). This results in a convex problem that can be solved separately for each node of the graph. In this work, we leverage the cardinality constraint L0 norm, which is known to properly induce sparsity, and further combine it with an L2 norm to better model the non-zero coefficients. We show that our proposed estimators achieve an improved sample complexity, both (a) theoretically, by reaching new state-of-the-art upper bounds for recovery guarantees, and (b) empirically, by showing sharper phase transitions between poor and full recovery for graph topologies studied in the literature, when compared to their L1-based state-of-the-art methods.
Submitted 15 September, 2021; v1 submitted 3 December, 2020;
originally announced December 2020.
-
Query Training: Learning a Worse Model to Infer Better Marginals in Undirected Graphical Models with Hidden Variables
Authors:
Miguel Lázaro-Gredilla,
Wolfgang Lehrach,
Nishad Gothoskar,
Guangyao Zhou,
Antoine Dedieu,
Dileep George
Abstract:
Probabilistic graphical models (PGMs) provide a compact representation of knowledge that can be queried in a flexible way: after learning the parameters of a graphical model once, new probabilistic queries can be answered at test time without retraining. However, when using undirected PGMs with hidden variables, two sources of error typically compound in all but the simplest models: (a) learning error (both computing the partition function and integrating out the hidden variables are intractable); and (b) prediction error (exact inference is also intractable). Here we introduce query training (QT), a mechanism to learn a PGM that is optimized for the approximate inference algorithm that will be paired with it. The resulting PGM is a worse model of the data (as measured by the likelihood), but it is tuned to produce better marginals for a given inference algorithm. Unlike prior works, our approach preserves the querying flexibility of the original PGM: at test time, we can estimate the marginal of any variable given any partial evidence. We demonstrate experimentally that QT can be used to learn a challenging 8-connected grid Markov random field with hidden variables and that it consistently outperforms the state-of-the-art AdVIL when tested on three undirected models across multiple datasets.
Submitted 25 February, 2021; v1 submitted 11 June, 2020;
originally announced June 2020.
-
Learning Sparse Classifiers: Continuous and Mixed Integer Optimization Perspectives
Authors:
Antoine Dedieu,
Hussein Hazimeh,
Rahul Mazumder
Abstract:
We consider a discrete optimization formulation for learning sparse classifiers, where the outcome depends upon a linear combination of a small subset of features. Recent work has shown that mixed integer programming (MIP) can be used to solve (to optimality) $\ell_0$-regularized regression problems at scales much larger than what was conventionally considered possible. Despite their usefulness, MIP-based global optimization approaches are significantly slower compared to the relatively mature algorithms for $\ell_1$-regularization and heuristics for nonconvex regularized problems. We aim to bridge this gap in computation times by developing new MIP-based algorithms for $\ell_0$-regularized classification. We propose two classes of scalable algorithms: an exact algorithm that can handle $p\approx 50,000$ features in a few minutes, and approximate algorithms that can address instances with $p\approx 10^6$ in times comparable to the fast $\ell_1$-based algorithms. Our exact algorithm is based on the novel idea of integrality generation, which solves the original problem (with $p$ binary variables) via a sequence of mixed integer programs that involve a small number of binary variables. Our approximate algorithms are based on coordinate descent and local combinatorial search. In addition, we present new estimation error bounds for a class of $\ell_0$-regularized estimators. Experiments on real and synthetic data demonstrate that our approach leads to models with considerably improved statistical performance (especially, variable selection) when compared to competing methods.
Submitted 6 June, 2021; v1 submitted 17 January, 2020;
originally announced January 2020.
-
An error bound for Lasso and Group Lasso in high dimensions
Authors:
Antoine Dedieu
Abstract:
We leverage recent advances in high-dimensional statistics to derive new L2 estimation upper bounds for Lasso and Group Lasso in high dimensions. For Lasso, our bounds scale as $(k^*/n) \log(p/k^*)$, where $n\times p$ is the size of the design matrix and $k^*$ the dimension of the ground truth $\boldsymbol{\beta}^*$, and match the optimal minimax rate. For Group Lasso, our bounds scale as $(s^*/n) \log\left( G / s^* \right) + m^* / n$, where $G$ is the total number of groups and $m^*$ the number of coefficients in the $s^*$ groups which contain $\boldsymbol{\beta}^*$, and improve over existing results. We additionally show that when the signal is strongly group-sparse, Group Lasso is superior to Lasso.
Submitted 26 February, 2020; v1 submitted 21 December, 2019;
originally announced December 2019.
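To get a feel for the two rates quoted in this abstract, the sketch below evaluates $(k^*/n) \log(p/k^*)$ and $(s^*/n) \log(G/s^*) + m^*/n$ for one hypothetical setting. The bounds hold only up to constants that the abstract does not state, so the comparison is indicative at best; the function names and the chosen values of $n$, $p$, $G$, $s^*$ and $m^*$ are assumptions made for illustration.

```python
import math


def lasso_rate(n, p, k_star):
    """Scaling (k*/n) log(p/k*) of the Lasso bound, constants omitted."""
    return (k_star / n) * math.log(p / k_star)


def group_lasso_rate(n, num_groups, s_star, m_star):
    """Scaling (s*/n) log(G/s*) + m*/n of the Group Lasso bound, constants omitted."""
    return (s_star / n) * math.log(num_groups / s_star) + m_star / n


# Hypothetical strongly group-sparse setting: p = 1000 features split into
# G = 100 groups of 10, with s* = 5 fully dense active groups (k* = m* = 50).
N_SAMPLES = 500
LASSO = lasso_rate(n=N_SAMPLES, p=1000, k_star=50)
GROUP = group_lasso_rate(n=N_SAMPLES, num_groups=100, s_star=5, m_star=50)
```

In this setting the Group Lasso expression is the smaller of the two, because the logarithmic factor multiplies the number of active groups rather than the number of active coefficients.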
-
Improved error rates for sparse (group) learning with Lipschitz loss functions
Authors:
Antoine Dedieu
Abstract:
We study a family of sparse estimators defined as minimizers of some empirical Lipschitz loss function -- which include the hinge loss, the logistic loss and the quantile regression loss -- with a convex, sparse or group-sparse regularization. In particular, we consider the L1 norm on the coefficients, its sorted Slope version, and the Group L1-L2 extension. We propose a new theoretical framework that uses common assumptions in the literature to simultaneously derive new high-dimensional L2 estimation upper bounds for all three regularization schemes. For L1 and Slope regularizations, our bounds scale as $(k^*/n) \log(p/k^*)$ -- $n\times p$ is the size of the design matrix and $k^*$ the dimension of the theoretical loss minimizer $\boldsymbol{\beta}^*$ -- and match the optimal minimax rate achieved for the least-squares case. For Group L1-L2 regularization, our bounds scale as $(s^*/n) \log\left( G / s^* \right) + m^* / n$ -- $G$ is the total number of groups and $m^*$ the number of coefficients in the $s^*$ groups which contain $\boldsymbol{\beta}^*$ -- and improve over the least-squares case. We show that, when the signal is strongly group-sparse, Group L1-L2 is superior to L1 and Slope. In addition, we adapt our approach to the sub-Gaussian linear regression framework and reach the optimal minimax rate for Lasso, and an improved rate for Group-Lasso. Finally, we release an accelerated proximal algorithm that computes the nine main convex estimators of interest when the number of variables is of the order of 100,000s.
Submitted 21 September, 2021; v1 submitted 19 October, 2019;
originally announced October 2019.
-
Learning higher-order sequential structure with cloned HMMs
Authors:
Antoine Dedieu,
Nishad Gothoskar,
Scott Swingle,
Wolfgang Lehrach,
Miguel Lázaro-Gredilla,
Dileep George
Abstract:
Variable order sequence modeling is an important problem in artificial and natural intelligence. While overcomplete Hidden Markov Models (HMMs), in theory, have the capacity to represent long-term temporal structure, they often fail to learn it, converging to local minima instead. We show that by constraining HMMs with a simple sparsity structure inspired by biology, we can make them learn variable order sequences efficiently. We call this model cloned HMM (CHMM) because the sparsity structure enforces that many hidden states map deterministically to the same emission state. CHMMs with over 1 billion parameters can be efficiently trained on GPUs without being severely affected by the credit diffusion problem of standard HMMs. Unlike n-grams and sequence memoizers, CHMMs can model temporal dependencies at arbitrarily long distances and recognize contexts with 'holes' in them. Compared to Recurrent Neural Networks and their Long Short-Term Memory extensions (LSTMs), CHMMs are generative models that can natively deal with uncertainty. Moreover, CHMMs return a higher-order graph that represents the temporal structure of the data which can be useful for community detection, and for building hierarchical models. Our experiments show that CHMMs can beat n-grams, sequence memoizers, and LSTMs on character-level language modeling tasks. CHMMs can be a viable alternative to these methods in some tasks that require variable order sequence modeling and the handling of uncertainty.
Submitted 15 May, 2019; v1 submitted 1 May, 2019;
originally announced May 2019.
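The cloning constraint described in this abstract (many hidden states mapping deterministically to the same emission) can be sketched with a toy forward pass. The code below is an illustrative assumption rather than the authors' implementation: it uses a hypothetical model with two symbols and two clones per symbol, and shows that the forward recursion only needs to keep probability mass on the clones of the symbol just observed.

```python
import itertools

# Hypothetical toy model: 2 observable symbols, 2 clones each -> 4 hidden states.
CLONE_OF = [0, 0, 1, 1]  # CLONE_OF[s] is the single symbol hidden state s can emit
INITIAL = [0.25, 0.25, 0.25, 0.25]  # initial distribution over hidden states
TRANSITION = [  # TRANSITION[r][s] = P(next state s | current state r); rows sum to 1
    [0.1, 0.2, 0.6, 0.1],
    [0.3, 0.3, 0.2, 0.2],
    [0.5, 0.1, 0.1, 0.3],
    [0.2, 0.2, 0.2, 0.4],
]


def sequence_likelihood(observations):
    """Forward algorithm with deterministic emissions: non-matching clones get zero mass."""
    num_states = len(CLONE_OF)
    alpha = [
        INITIAL[s] if CLONE_OF[s] == observations[0] else 0.0
        for s in range(num_states)
    ]
    for symbol in observations[1:]:
        alpha = [
            sum(alpha[r] * TRANSITION[r][s] for r in range(num_states))
            if CLONE_OF[s] == symbol else 0.0
            for s in range(num_states)
        ]
    return sum(alpha)


def total_probability(length):
    """Sum the likelihood over every symbol sequence of the given length."""
    symbols = sorted(set(CLONE_OF))
    return sum(
        sequence_likelihood(seq)
        for seq in itertools.product(symbols, repeat=length)
    )
```

As a sanity check, `total_probability(3)` sums the likelihoods of all eight length-3 sequences and should equal 1 up to floating-point error. Because emissions are deterministic, the emission matrix never has to be stored or learned; only the transition matrix carries parameters.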
-
Solving L1-regularized SVMs and related linear programs: Revisiting the effectiveness of Column and Constraint Generation
Authors:
Antoine Dedieu,
Rahul Mazumder,
Haoyue Wang
Abstract:
The linear Support Vector Machine (SVM) is a classic classification technique in machine learning. Motivated by applications in modern high dimensional statistics, we consider penalized SVM problems involving the minimization of a hinge-loss function with a convex sparsity-inducing regularizer such as: the L1-norm on the coefficients, its grouped generalization and the sorted L1-penalty (aka Slope). Each problem can be expressed as a Linear Program (LP) and is computationally challenging when the number of features and/or samples is large -- the current state of algorithms for these problems is rather nascent when compared to the usual L2-regularized linear SVM. To this end, we propose new computational algorithms for these LPs by bringing together techniques from (a) classical column (and constraint) generation methods and (b) first order methods for non-smooth convex optimization -- techniques that are rarely used together for solving large scale LPs. These components have their respective strengths; and while they are found to be useful as separate entities, they have not been used together in the context of solving large scale LPs such as the ones studied herein. Our approach complements the strengths of (a) and (b) -- leading to a scheme that seems to significantly outperform commercial solvers as well as specialized implementations for these problems. We present numerical results on a series of real and synthetic datasets demonstrating the surprising effectiveness of classic column/constraint generation methods in the context of challenging LP-based machine learning tasks.
Submitted 27 August, 2021; v1 submitted 6 January, 2019;
originally announced January 2019.
-
Hierarchical Modeling and Shrinkage for User Session Length Prediction in Media Streaming
Authors:
Antoine Dedieu,
Rahul Mazumder,
Zhen Zhu,
Hossein Vahabi
Abstract:
An important metric of users' satisfaction and engagement within on-line streaming services is the user session length, i.e. the amount of time they spend on a service continuously without interruption. Being able to predict this value directly benefits the recommendation and ad pacing contexts in music and video streaming services. Recent research has shown that predicting the exact amount of time spent is highly nontrivial due to many external factors for which a user can end a session, and the lack of predictive covariates. Most of the other related literature on duration based user engagement has focused on dwell time for websites, for search and display ads, mainly for post-click satisfaction prediction or ad ranking.
In this work we present a novel framework inspired by hierarchical Bayesian modeling to predict, at the moment of login, the amount of time a user will spend in the streaming service. The time spent by a user on a platform depends upon user-specific latent variables which are learned via hierarchical shrinkage. Our framework enjoys theoretical guarantees and naturally incorporates flexible parametric/nonparametric models on the covariates, including models robust to outliers. Our proposal is found to outperform state-of-the-art estimators in terms of efficiency and predictive performance on real-world public and private datasets.
Submitted 22 June, 2018; v1 submitted 4 March, 2018;
originally announced March 2018.