MotionLM: Multi-Agent Motion Forecasting as Language Modeling

Ari Seff  Brian Cera  Dian Chen*  Mason Ng  Aurick Zhou  Nigamaa Nayakanti  Khaled S. Refaat  Rami Al-Rfou  Benjamin Sapp

Waymo

*Work done during an internship at Waymo. Contact: {aseff, bensapp}@waymo.com

Abstract

Reliable forecasting of the future behavior of road agents is a critical component to safe planning in autonomous vehicles. Here, we represent continuous trajectories as sequences of discrete motion tokens and cast multi-agent motion prediction as a language modeling task over this domain. Our model, MotionLM, provides several advantages: First, it does not require anchors or explicit latent variable optimization to learn multimodal distributions. Instead, we leverage a single standard language modeling objective, maximizing the average log probability over sequence tokens. Second, our approach bypasses post-hoc interaction heuristics where individual agent trajectory generation is conducted prior to interactive scoring. Instead, MotionLM produces joint distributions over interactive agent futures in a single autoregressive decoding process. In addition, the model's sequential factorization enables temporally causal conditional rollouts. The proposed approach establishes new state-of-the-art performance for multi-agent motion prediction on the Waymo Open Motion Dataset, ranking 1st on the interactive challenge leaderboard.

1. Introduction

Modern sequence models often employ a next-token prediction objective that incorporates minimal domain-specific assumptions. For example, autoregressive language models [3, 10] are pre-trained to maximize the probability of the next observed subword conditioned on the previous text; there is no predefined notion of parsing or syntax built in. This approach has found success in continuous domains as well, such as audio [2] and image generation [49]. Leveraging the flexibility of arbitrary categorical distributions, the above works represent continuous data with a set of discrete tokens, reminiscent of language model vocabularies.

Figure 1. Our model autoregressively generates sequences of discrete motion tokens for a set of agents to produce consistent interactive trajectory forecasts.

In driving scenarios, road users may be likened to participants in a constant dialogue, continuously exchanging a dynamic series of actions and reactions mirroring the fluidity of communication. Navigating this rich web of interactions requires the ability to anticipate the likely maneuvers and responses of the involved actors. Just as today's language models can capture sophisticated distributions over conversations, can we leverage similar sequence models to forecast the behavior of road agents?

A common simplification to modeling the full future world state has been to decompose the joint distribution of agent behavior into independent per-agent marginal distributions. Although there has been much progress on this task [8, 47, 12, 25, 31, 5, 6, 21], marginal predictions are insufficient as inputs to a planning system; they do not represent the future dependencies between the actions of different agents, leading to inconsistent scene-level forecasting.
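To make this distinction concrete, the two factorizations can be written side by side. The second form is the one MotionLM adopts (it is the same factorization given later in Section 3.1, where $A_t$ denotes the set of agent actions at future step $t$ and $S$ the scene inputs); the shorthand $Y^n$ for agent $n$'s full future trajectory is ours, introduced only for this comparison:

$$p(Y^1, \ldots, Y^N \mid S) \;\approx\; \prod_{n=1}^{N} p(Y^n \mid S) \qquad \text{(independent per-agent marginals)}$$

$$p_\theta(A_1, \ldots, A_T \mid S) \;=\; \prod_{t=1}^{T} p_\theta(A_t \mid A_{<t}, S) \qquad \text{(temporally causal joint factorization)}$$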
Of the existing joint prediction approaches, some apply a separation between marginal trajectory generation and interactive scoring [40, 42, 29]. For example, Luo et al. [29] initially produce a small set of marginal trajectories for each agent independently, before assigning a learned potential to each inter-agent trajectory pair through a belief propagation algorithm. Sun et al. [42] use a manual heuristic to tag agents as either influencers or reactors, and then pair marginal and conditional predictions to form joint predictions. We also note that because these approaches do not explicitly model temporal dependencies within trajectories, their conditional forecasts may be more susceptible to spurious correlations, leading to less realistic reaction predictions. For example, these models can capture the correlation between a lead agent decelerating and a trailing agent decelerating, but may fail to infer which one is likely causing the other to slow down. In contrast, previous joint models employing an autoregressive factorization, e.g., [36, 43, 39], do respect future temporal dependencies. These models have generally relied on explicit latent variables for diversity, optimized via either an evidence lower bound or normalizing flow.

In this work, we combine trajectory generation and interaction modeling in a single, temporally causal, decoding process over discrete motion tokens (Fig. 1), leveraging a simple training objective inspired by autoregressive language models. Our model, MotionLM, is trained to directly maximize the log probability of these token sequences among interacting agents. At inference time, joint trajectories are produced step-by-step, where interacting agents sample tokens simultaneously, attend to one another, and repeat. In contrast to previous approaches which manually enforce trajectory multimodality during training, our model is entirely free of latent variables and anchors, with multimodality emerging solely as a characteristic of sampling. MotionLM may be applied to several downstream behavior prediction tasks, including marginal, joint, and conditional predictions.

This work makes the following contributions:

1. We cast multi-agent motion forecasting as a language modeling task, introducing a temporally causal decoder over discrete motion tokens trained with a causal language modeling loss.

2. We pair sampling from our model with a simple rollout aggregation scheme that facilitates weighted mode identification for joint trajectories, establishing new state-of-the-art performance on the Waymo Open Motion Dataset interaction prediction challenge (6% improvement in the ranking joint mAP metric).

3. We perform extensive ablations of our approach as well as analysis of its temporally causal conditional predictions, which are largely unsupported by current joint forecasting models.

2. Related work

Marginal trajectory prediction. Behavior predictors are often evaluated on their predictions for individual agents, e.g., in recent motion forecasting benchmarks [14, 9, 4, 51, 37]. Previous methods process the rasterized scene with CNNs [8, 5, 12, 17]; more recent works represent scenes with points and polylines and process them with GNNs [6, 25, 47, 22] or transformers [31, 40, 20]. To handle the multimodality of future trajectories, some models manually enforce diversity via predefined anchors [8, 5] or intention points [40, 52, 28]. Other works learn diverse modes with latent variable modeling, e.g., [24].
While these works produce multimodal future trajectories of individual agents, they only capture the marginal distributions of the possible agent futures and do not model the interactions among agents.

Interactive trajectory prediction. Interactive behavior predictors model the joint distribution of agents' futures. This task has been far less studied than marginal motion prediction. For example, the Waymo Open Motion Dataset (WOMD) [14] challenge leaderboard currently has 71 published entries for marginal prediction compared to only 14 for interaction prediction.

Ngiam et al. [32] model the distribution of future trajectories with a transformer-based mixture model outputting joint modes. To avoid the exponential blow-up from a full joint model, Luo et al. [29] model pairwise joint distributions. Tolstaya et al. [44], Song et al. [41], and Sun et al. [42] consider conditional predictions by exposing the future trajectory of one agent when predicting for another agent. Shi et al. [40] derive joint probabilities by simply multiplying marginal trajectory probabilities, essentially treating agents as independent, which may limit accuracy. Cui et al. [11], Casas et al. [7], and Girgis et al. [15] reduce the full-fledged joint distribution using global latent variables. Unlike our autoregressive factorization, the above models typically follow "one-shot" (parallel across time) factorizations and do not explicitly model temporally causal interactions.

Autoregressive trajectory prediction. Autoregressive behavior predictors generate trajectories over successive time intervals to produce scene-consistent multi-agent trajectories. Rhinehart et al. [36], Tang and Salakhutdinov [43], Amirloo et al. [1], Salzmann et al. [39], and Yuan et al. [50] predict multi-agent future trajectories using latent variable models. Lu et al. [27] explore autoregressively outputting keyframes via mixtures of Gaussians prior to filling in the remaining states. In [18], an adversarial objective is combined with parallel beam search to learn multi-agent rollouts. Unlike most autoregressive trajectory predictors, our method does not rely on latent variables or beam search and generates multimodal joint trajectories by directly sampling from a learned distribution of discrete motion token sequences.

Figure 2. MotionLM architecture. We first encode heterogeneous scene features relative to each modeled agent (left) as scene embeddings of shape [R, N, ·, H]. Here, R refers to the number of rollouts, N refers to the number of (jointly modeled) agents, and H is the dimensionality of each embedding. We repeat the embeddings R times in the batch dimension for parallel sampling during inference. Next, a trajectory decoder autoregressively rolls out T discrete motion tokens for multiple agents in a temporally causal manner (center). Finally, representative modes of the rollouts may be recovered via a simple aggregation utilizing k-means clustering initialized with non-maximum suppression (right).
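As a rough illustration of the aggregation stage named in the Figure 2 caption (k-means clustering initialized with non-maximum suppression over sampled rollouts), the sketch below groups R joint rollouts of shape [R, N, T, 2] into a small set of weighted modes. The distance function (final-waypoint displacement averaged over agents), the NMS radius, and the membership-count probabilities are illustrative assumptions, not the paper's exact procedure.

# Hypothetical sketch of the rollout aggregation in Fig. 2: NMS picks well-separated
# seed rollouts, K-means refines them, and cluster membership gives mode weights.
import numpy as np

def aggregate_rollouts(rollouts, k=6, nms_radius=2.0):
    """rollouts: [R, N, T, 2] joint trajectories -> (modes [K, N, T, 2], probs [K])."""
    R = rollouts.shape[0]
    # Scene-level distance between two rollouts: displacement at the final
    # waypoint, averaged over the N jointly modeled agents (an assumption).
    feats = rollouts[:, :, -1, :]                                        # [R, N, 2]
    d = np.linalg.norm(feats[:, None] - feats[None], axis=-1).mean(-1)   # [R, R]

    # Non-maximum suppression to pick up to k well-separated seed rollouts.
    # (With per-rollout scores one would visit them in decreasing score order.)
    seeds, suppressed = [], np.zeros(R, bool)
    for i in range(R):
        if suppressed[i]:
            continue
        seeds.append(i)
        suppressed |= d[i] < nms_radius
        if len(seeds) == k:
            break

    # K-means refinement: assign every rollout to its nearest seed, then
    # recompute each mode as the mean of its members.
    centers = rollouts[seeds]                                            # [K, N, T, 2]
    for _ in range(10):
        dc = np.linalg.norm(
            rollouts[:, None, :, -1] - centers[None, :, :, -1], axis=-1).mean(-1)  # [R, K]
        assign = dc.argmin(-1)
        centers = np.stack([rollouts[assign == j].mean(0) if np.any(assign == j)
                            else centers[j] for j in range(len(seeds))])
    probs = np.bincount(assign, minlength=len(seeds)) / R
    return centers, probs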
Discrete sequence modeling in continuous domains. When generating sequences in continuous domains, one effective approach is to discretize the output space and predict categorical distributions at each step. For example, in image generation, van den Oord et al. [45] sequentially predict the uniformly discretized pixel values for each channel and find this to perform better than outputting continuous values directly. Multiple works on generating images from text, such as [35] and [49], use a two-stage process with a learned tokenizer to map images to discrete tokens and an autoregressive model to predict the discrete tokens given the text prompt. For audio generation, WaveNet [46] applies a µ-law transformation before discretizing. Borsos et al. [2] learn a hierarchical tokenizer/detokenizer, with the main transformer sequence model operating on the intermediate discrete tokens. When generating polygonal meshes, Nash et al. [30] uniformly quantize the coordinates of each vertex. In MotionLM, we employ a simple uniform quantization of axis-aligned deltas between consecutive waypoints of agent trajectories.

3. MotionLM

We aim to model a distribution over multi-agent interactions in a general manner that can be applied to distinct downstream tasks, including marginal, joint, and conditional forecasting. This requires an expressive generative framework capable of capturing the substantial multimodality in driving scenarios. In addition, we take care here to preserve temporal dependencies; i.e., inference in our model follows a directed acyclic graph with the parents of every node residing earlier in time and children residing later (Section 3.3, Fig. 4). This enables conditional forecasts that more closely resemble causal interventions [34] by eliminating certain spurious correlations that can otherwise result from disobeying temporal causality². We observe that joint models that do not preserve temporal dependencies may have a limited ability to predict realistic agent reactions – a key use in planning (Section 4.6). To this end, we leverage an autoregressive factorization of our future decoder, where agents' motion tokens are conditionally dependent on all previously sampled tokens and trajectories are rolled out sequentially (Fig. 2).

Let S represent the input data for a given scenario. This may include context such as roadgraph elements, traffic light states, as well as features describing road agents (e.g., vehicles, cyclists, and pedestrians) and their recent histories, all provided at the current timestep t = 0. Our task is to generate predictions for joint agent states $Y_t \doteq \{y_t^1, y_t^2, \ldots, y_t^N\}$ for N agents of interest at future timesteps t = 1, ..., T. Rather than complete states, these future state targets are typically two-dimensional waypoints (i.e., (x, y) coordinates), with T waypoints forming the full ground truth trajectory for an individual agent.

² We make no claims that our model is capable of directly modeling causal relationships (due to the theoretical limits of purely observational data and unobserved confounders). Here, we solely take care to avoid breaking temporal causality.

3.1. Joint probabilistic rollouts

In our modeling framework, we sample a predicted action for each target agent at each future timestep. These actions are formulated as discrete motion tokens from a finite vocabulary, as described later in Section 3.2.2.
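For illustration, here is a minimal sketch of the uniform quantization of axis-aligned waypoint deltas mentioned at the end of Section 2, mapping a trajectory to discrete motion tokens and back. The bin count of 13 per coordinate is borrowed from Appendix A (where it is obtained only after Verlet wrapping, which this sketch omits), and the delta range is a placeholder assumption; this is not the exact tokenizer used in the paper.

# Sketch: uniformly quantize axis-aligned deltas between consecutive waypoints
# into discrete motion tokens, and invert the mapping (up to quantization error).
import numpy as np

NUM_BINS = 13        # per coordinate (the paper reaches 13 only after Verlet wrapping)
DELTA_MAX = 4.0      # assumed max |delta| in meters per step; a placeholder value
EDGES = np.linspace(-DELTA_MAX, DELTA_MAX, NUM_BINS + 1)
CENTERS = 0.5 * (EDGES[:-1] + EDGES[1:])

def tokenize(waypoints):
    """waypoints: [T+1, 2] (x, y) positions -> [T] discrete motion tokens."""
    deltas = np.diff(waypoints, axis=0)                               # [T, 2] axis-aligned deltas
    bins = np.clip(np.digitize(deltas, EDGES) - 1, 0, NUM_BINS - 1)   # [T, 2] bin indices
    return bins[:, 0] * NUM_BINS + bins[:, 1]                         # Cartesian product -> one token

def detokenize(tokens, start_xy):
    """Integrate the de-quantized deltas from start_xy to recover waypoints."""
    bx, by = tokens // NUM_BINS, tokens % NUM_BINS
    deltas = np.stack([CENTERS[bx], CENTERS[by]], axis=-1)            # [T, 2]
    return start_xy + np.cumsum(deltas, axis=0)

With 13 bins per coordinate, the resulting vocabulary has 13² = 169 tokens, matching the count reported in Appendix A.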
Let $a_t^n$ represent the target action (derived from the ground truth waypoints) for the nth agent at time t, with $A_t \doteq \{a_t^1, a_t^2, \ldots, a_t^N\}$ representing the set of target actions for all agents at time t.

Factorization. We factorize the distribution over joint future action sequences as a product of conditionals:

$$p_\theta(A_1, A_2, \ldots, A_T \mid S) = \prod_{t=1}^{T} p_\theta(A_t \mid A_{<t}, S)$$

[…] 99% of the WOMD dataset.

Verlet-wrapped action space. Once the above delta action space has the Verlet wrapper applied, we only require 13 bins for each coordinate. This results in a total of 13² = 169 discrete motion tokens that the model can select from, with the Cartesian product of the per-coordinate bins comprising the final vocabulary.

Sequence lengths. For 8-second futures, the model outputs 16 motion tokens for each agent (note that WOMD evaluates predictions at 2 Hz). For the two-agent interactive split, our flattened agent-time token sequences (Section 3.2.2) have length 2 × 16 = 32.

B. Implementation details

B.1. Scene encoder

We follow the design of the early fusion network proposed by [31] as the scene encoding backbone of our model. The following hyperparameters are used:
• Number of layers: 4
• Hidden size: 256
• Feed-forward network intermediate size: 1024
• Number of attention heads: 4
• Number of latent queries: 92
• Activation: ReLU

B.2. Trajectory decoder

To autoregressively decode motion token sequences, we utilize a causal transformer decoder that takes in the motion tokens as queries and the scene encodings as context. We use the following model hyperparameters:
• Number of layers: 4
• Hidden size: 256
• Feed-forward network intermediate size: 1024
• Number of attention heads: 4
• Activation: ReLU

Figure 8. Masked causal attention between two agents during training. We flatten the agent and time axes, leading to an NT × NT attention mask. The agents may attend to each other's previous motion tokens (solid squares) but no future tokens (empty squares).

B.3. Optimization

We train our model to maximize the likelihood of the ground truth motion token sequences via teacher forcing. We use the following training hyperparameters:
• Number of training steps: 600000
• Batch size: 256
• Learning rate schedule: Linear decay
• Initial learning rate: 0.0006
• Final learning rate: 0.0
• Optimizer: AdamW
• Weight decay: 0.6

B.4. Inference

We found nucleus sampling [16], commonly used with language models, to be helpful for improving sample quality while maintaining diversity. Here we set the top-p parameter to 0.95.
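For concreteness, the following is a generic implementation of the nucleus (top-p) sampling step from [16], applied to the logits over the motion-token vocabulary for one agent at one timestep; it illustrates the standard technique rather than the model's actual sampling code.

# Sketch of nucleus (top-p) sampling over the motion-token vocabulary (Sec. B.4).
import numpy as np

def nucleus_sample(logits, top_p=0.95, rng=None):
    """logits: [V] unnormalized scores for one agent at one timestep -> token id."""
    rng = rng or np.random.default_rng()
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)                         # most to least likely
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1   # smallest prefix with mass >= top_p
    kept = order[:cutoff]
    kept_probs = probs[kept] / probs[kept].sum()       # renormalize within the nucleus
    return int(rng.choice(kept, p=kept_probs))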
C. Metrics descriptions

C.1. WOMD metrics

All metrics for the two WOMD [14] benchmarks are evaluated at three time steps (3, 5, and 8 seconds) and are averaged over all object types to obtain the final value. For joint metrics, a scene is attributed to an object class (vehicle, pedestrian, or cyclist) according to the least common type of agent that is present in that interaction, with cyclist being the rarest object class and vehicles being the most common. Up to 6 trajectories are produced by the models for each target agent in each scene, which are then used for metric evaluation.

mAP & Soft mAP. mAP measures precision of prediction likelihoods and is calculated by first bucketing ground truth futures of objects into eight discrete classes of intent: straight, straight-left, straight-right, left, right, left u-turn, right u-turn, and stationary. For marginal predictions, a prediction trajectory is considered a "miss" if it exceeds a lateral or longitudinal error threshold at a specified timestep T. Similarly, for joint predictions, a prediction is considered a "miss" if none of the k joint predictions contains trajectories for all predicted objects within a given lateral and longitudinal error threshold, with respect to the ground truth trajectories for each agent. Trajectory predictions classified as a miss are labeled as false positives. If multiple predictions fall within the error thresholds (i.e., are not misses), then, consistent with object detection mAP metrics, only one true positive is allowed for each scene, assigned to the highest-confidence prediction; all other predictions for the object are assigned a false positive. To compute the mAP metric, bucket entries are sorted and a P/R curve is computed for each bucket; averaging precision values over various likelihood thresholds for all intent buckets yields the final mAP value. Soft mAP differs only in that additional matching predictions (other than the most likely match) are ignored instead of being assigned a false positive, and so are not penalized in the metric computation.

Miss rate. Using the same definition of a "miss" described above for either marginal or joint predictions, miss rate measures the fraction of scenarios that fail to generate any predictions within the lateral and longitudinal error thresholds, relative to the ground truth future.

minADE & minFDE. minADE measures the Euclidean distance error averaged over all timesteps for the closest prediction, relative to ground truth. In contrast, minFDE considers only the distance error at the final timestep. For joint predictions, minADE and minFDE are calculated as the average value over both agents.

C.2. Prediction overlap

As described in [29], the WOMD [14] overlap metric only considers overlap between predictions and ground truth. Here we use a prediction overlap metric to assess scene-level consistency for joint models. Our implementation is similar to [29], except we follow the convention of the WOMD challenge of only requiring models to generate (x, y) waypoints; headings are inferred as in [14]. If the bounding boxes of two predicted agents collide at any timestep in a scene, that counts as an overlap/collision for that scene. The final prediction overlap rate is calculated as the sum of per-scene overlaps, averaged across the dataset.
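A sketch of how the per-scene overlap check above could be implemented for a pair of predicted agents is given below, using a separating-axis test between oriented boxes at each timestep. Headings and box dimensions are taken as given inputs here (the heading inference from (x, y) waypoints, as in [14], is not reproduced), so this is an illustrative implementation rather than the exact one used for the reported numbers.

# Sketch of the scene-level prediction overlap check in Sec. C.2.
import numpy as np

def box_corners(center, heading, length, width):
    """Corners [4, 2] of an oriented rectangle centered at `center`."""
    c, s = np.cos(heading), np.sin(heading)
    rot = np.array([[c, -s], [s, c]])
    half = np.array([[ length,  width], [ length, -width],
                     [-length, -width], [-length,  width]]) / 2.0
    return center + half @ rot.T

def boxes_collide(corners_a, corners_b):
    """Separating-axis test between two oriented rectangles (convex quads)."""
    for corners in (corners_a, corners_b):
        edges = np.roll(corners, -1, axis=0) - corners
        normals = np.stack([-edges[:, 1], edges[:, 0]], axis=-1)  # candidate separating axes
        for axis in normals:
            pa, pb = corners_a @ axis, corners_b @ axis
            if pa.max() < pb.min() or pb.max() < pa.min():
                return False                                       # separating axis found
    return True

def scene_has_overlap(traj_a, traj_b, head_a, head_b, dims_a, dims_b):
    """traj_*: [T, 2] waypoints; head_*: [T] headings; dims_*: (length, width)."""
    for t in range(traj_a.shape[0]):
        ca = box_corners(traj_a[t], head_a[t], *dims_a)
        cb = box_corners(traj_b[t], head_b[t], *dims_b)
        if boxes_collide(ca, cb):
            return True                  # any colliding timestep marks the scene as an overlap
    return False

# Dataset-level prediction overlap rate = mean of scene_has_overlap over all scenes.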
D. Additional evaluation

Ablations. Tables 5 and 6 display joint prediction performance across varying interactive attention frequencies and numbers of rollouts, respectively. In addition to the ensembled model performance, single-replica performance is evaluated. Standard deviations are computed for each metric over 8 independently trained replicas.

Scaling analysis. Table 7 displays the performance of different model sizes on the WOMD interactive split, all trained with the same optimization hyperparameters. We vary the number of layers, hidden size, and number of attention heads in the encoder and decoder proportionally. Due to external constraints, in this study we only train a single replica for each parameter count. We observe that a model with 27M parameters overfits while 300K underfits. Both the 1M and 9M models perform decently. In this paper, our main results use 9M-parameter replicas.

Latency analysis. Table 8 provides inference latency on the latest generation of GPUs across different numbers of rollouts. These were measured for a single-replica joint model rolling out two agents.

E. Visualizations

In the supplementary zip file, we have included GIF animations of the model's greatest-probability predictions in various scenes. Each example below displays the associated scene ID, which is also contained in the corresponding GIF filename. We describe the examples here.

E.1. Marginal vs. Joint

• Scene ID: 286a65c777726df3
  Marginal: The turning vehicle and crossing cyclist collide.
  Joint: The vehicle yields to the cyclist before turning.

• Scene ID: 440bbf422d08f4c0
  Marginal: The turning vehicle collides with the crossing vehicle in the middle of the intersection.
  Joint: The turning vehicle yields and collision is avoided.

• Scene ID: 38899bce1e306fb1
  Marginal: The lane-changing vehicle gets rear-ended by the vehicle in the adjacent lane.
  Joint: The adjacent vehicle slows down to allow the lane-changing vehicle to complete the maneuver.

Freq. (Hz) | Ensemble: minADE (↓) / minFDE (↓) / MR (↓) / mAP (↑) | Single replica: minADE (↓) / minFDE (↓) / MR (↓) / mAP (↑)
0.125 | 0.9120 / 2.0634 / 0.4222 / 0.2007 | 1.0681 (0.011) / 2.4783 (0.025) / 0.5112 (0.007) / 0.1558 (0.007)
0.25  | 0.9083 / 2.0466 / 0.4241 / 0.1983 | 1.0630 (0.009) / 2.4510 (0.025) / 0.5094 (0.006) / 0.1551 (0.006)
0.5   | 0.8931 / 2.0073 / 0.4173 / 0.2077 | 1.0512 (0.009) / 2.4263 (0.022) / 0.5039 (0.006) / 0.1588 (0.004)
1     | 0.8842 / 1.9898 / 0.4117 / 0.2040 | 1.0419 (0.014) / 2.4062 (0.032) / 0.5005 (0.008) / 0.1639 (0.005)
2     | 0.8831 / 1.9825 / 0.4092 / 0.2150 | 1.0345 (0.012) / 2.3886 (0.031) / 0.4943 (0.006) / 0.1687 (0.004)

Table 5. Joint prediction performance across varying interactive attention frequencies on the WOMD interactive validation set. Displayed are scene-level joint evaluation metrics. For the single-replica metrics, we include the standard deviation (across 8 replicas) in parentheses.

# Rollouts | Ensemble: minADE (↓) / minFDE (↓) / MR (↓) / mAP (↑) | Single replica: minADE (↓) / minFDE (↓) / MR (↓) / mAP (↑)
1   | 1.0534 / 2.3526 / 0.5370 / 0.1524 | 1.9827 (0.018) / 4.7958 (0.054) / 0.8182 (0.003) / 0.0578 (0.004)
2   | 0.9952 / 2.2172 / 0.4921 / 0.1721 | 1.6142 (0.011) / 3.8479 (0.032) / 0.7410 (0.003) / 0.0827 (0.004)
4   | 0.9449 / 2.1100 / 0.4561 / 0.1869 | 1.3655 (0.012) / 3.2060 (0.035) / 0.6671 (0.003) / 0.1083 (0.003)
8   | 0.9158 / 2.0495 / 0.4339 / 0.1934 | 1.2039 (0.013) / 2.7848 (0.035) / 0.5994 (0.004) / 0.1324 (0.003)
16  | 0.9010 / 2.0163 / 0.4196 / 0.2024 | 1.1254 (0.012) / 2.5893 (0.031) / 0.5555 (0.005) / 0.1457 (0.003)
32  | 0.8940 / 2.0041 / 0.4141 / 0.2065 | 1.0837 (0.013) / 2.4945 (0.035) / 0.5272 (0.005) / 0.1538 (0.004)
64  | 0.8881 / 1.9888 / 0.4095 / 0.2051 | 1.0585 (0.012) / 2.4411 (0.033) / 0.5114 (0.005) / 0.1585 (0.004)
128 | 0.8851 / 1.9893 / 0.4103 / 0.2074 | 1.0456 (0.012) / 2.4131 (0.033) / 0.5020 (0.006) / 0.1625 (0.004)
256 | 0.8856 / 1.9893 / 0.4078 / 0.2137 | 1.0385 (0.012) / 2.3984 (0.031) / 0.4972 (0.007) / 0.1663 (0.005)
512 | 0.8831 / 1.9825 / 0.4092 / 0.2150 | 1.0345 (0.012) / 2.3886 (0.031) / 0.4943 (0.006) / 0.1687 (0.004)

Table 6. Joint prediction performance across varying numbers of rollouts per replica on the WOMD interactive validation set. Displayed are scene-level joint evaluation metrics. For the single-replica metrics, we include the standard deviation (across 8 replicas) in parentheses.

Parameter count | Miss Rate (↓) | mAP (↑)
300K | 0.6047 | 0.1054
1M   | 0.5037 | 0.1713
9M   | 0.4972 | 0.1663
27M  | 0.6072 | 0.1376

Table 7. Joint prediction performance across varying model sizes on the WOMD interactive validation set. Displayed are scene-level joint mAP and miss rate for 256 rollouts for a single model replica (except for 9M, which displays the mean performance of 8 replicas).
• Scene ID: 2ea76e74b5025ec7
  Marginal: The cyclist crosses in front of the vehicle, leading to a collision.
  Joint: The cyclist waits for the vehicle to proceed before turning.

• Scene ID: 55b5fe989aa4644b
  Marginal: The cyclist lane changes in front of the adjacent vehicle, leading to a collision.
  Joint: The cyclist remains in their lane for the duration of the scene, avoiding a collision.

Number of rollouts | Latency (ms)
16  | 19.9 (0.19)
32  | 27.5 (0.25)
64  | 43.8 (0.26)
128 | 75.8 (0.23)
256 | 137.7 (0.19)

Table 8. Inference latency on the current generation of GPUs for different numbers of rollouts of the joint model. We display the mean and standard deviation (in parentheses) of the latency measurements for each setting.

E.2. Marginal vs. Conditional

"Conditional" here refers to temporally causal conditioning as described in the main text.

• Scene ID: 5ebba77f351358e2
  Marginal: The pedestrian crosses the street as a vehicle is turning, leading to a collision.
  Conditional: When conditioning on the vehicle's turning trajectory as a query, the pedestrian is instead predicted to remain stationary.

• Scene ID: d557eee96705c822
  Marginal: The modeled vehicle collides with the lead vehicle.
  Conditional: When conditioning on the lead vehicle's query trajectory, which remains stationary for a bit, the modeled vehicle instead comes to an appropriate stop.

• Scene ID: 9410e72c551f0aec
  Marginal: The modeled vehicle takes the turn slowly, unaware of the last turning vehicle's progress.
  Conditional: When conditioning on the query vehicle's turn progress, the modeled agent likewise makes more progress.

• Scene ID: c204982298bda1a1
  Marginal: The modeled vehicle proceeds slowly, unaware of the merging vehicle's progress.
  Conditional: When conditioning on the query vehicle's merge progress, the modeled agent accelerates behind.

E.3. Temporally Causal vs. Acausal Conditioning

• Scene ID: 4f39d4eb35a4c07c
  Joint prediction: The two modeled vehicles maintain speed for the duration of the scene.
  Conditioning on trailing agent:
  - Temporally causal: The lead vehicle is indifferent to the query trailing vehicle decelerating to a stop, proceeding along at a constant speed.
  - Acausal: The lead vehicle is "influenced" by the query vehicle decelerating. It likewise comes to a stop. Intuitively, this is an incorrect direction of influence that the acausal model has learned.
  Conditioning on lead agent:
  - Temporally causal: When conditioning on the query lead vehicle decelerating to a stop, the modeled trailing vehicle is likewise predicted to stop.
  - Acausal: In this case, the acausal conditional prediction is similar to the temporally causal conditional. The trailing vehicle is predicted to stop behind the query lead vehicle.
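To make the temporally causal conditioning referenced in E.2 and E.3 concrete, the sketch below rolls out one agent while teacher-forcing the query agent's tokens, so that at every step the decoder sees only tokens from earlier steps; an acausal variant would instead expose the entire query sequence before sampling begins. The model object and its next_token_logits interface are hypothetical stand-ins for illustration, not the paper's API.

# Sketch of a temporally causal conditional rollout for two agents:
# agent 0 is the query agent (tokens fixed), agent 1 is sampled.
import numpy as np

def causal_conditional_rollout(model, scene, query_tokens, num_steps, rng=None):
    """query_tokens: [T] fixed motion tokens for agent 0; returns [T] sampled
    tokens for agent 1, never conditioned on the query agent's future."""
    rng = rng or np.random.default_rng()
    history = []       # flattened (agent, step, token) triples visible to the decoder
    sampled = []
    for t in range(num_steps):
        # The step-t prediction conditions only on history from steps < t,
        # so tokens from the query agent's future are never exposed.
        logits = model.next_token_logits(scene, history, agent=1, step=t)  # hypothetical call
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        token = int(rng.choice(len(probs), p=probs))   # top-p sampling could be used here
        sampled.append(token)
        # Commit both agents' step-t tokens; they become visible from step t+1 onward.
        history += [(0, t, int(query_tokens[t])), (1, t, token)]
    return np.array(sampled)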