Gradient-based planning for world models at longer horizons


BallNav demo
Push-T demo

By Michael Psenka, Mike Rabbat, Aditi Krishnapriyan, Yann LeCun, Amir Bar

GRASP is a new gradient-based planner for learned dynamics (a “world model”) that makes long-horizon planning practical by (1) lifting the trajectory into virtual states so optimization is parallel across time, (2) adding stochasticity directly to the state iterates for exploration, and (3) reshaping gradients so actions get clean signals while we avoid brittle “state-input” gradients through high-dimensional vision models.

Large, learned world models are becoming increasingly capable. They can predict long sequences of future observations in high-dimensional visual spaces and generalize across tasks in ways that were hard to imagine a few years ago. As these models scale, they start to look less like task-specific predictors and more like general-purpose simulators.

But having a powerful predictive model is not the same as being able to use it effectively for control/learning/planning. In practice, long-horizon planning with modern world models remains fragile: optimization becomes ill-conditioned, non-greedy structure creates bad local minima, and high-dimensional latent spaces introduce subtle failure modes.

In this blog post, I describe the problems that motivated this project and our approach to addressing them: why planning with modern world models can be surprisingly fragile, why long horizons are the real stress test, and what we changed to make gradient-based planning far more robust.


This blog post discusses work done with Mike Rabbat, Aditi Krishnapriyan, Yann LeCun, and Amir Bar (* denotes equal advisorship), where we propose GRASP.


What is a world model?

These days, the term “world model” is quite overloaded, and depending on context can mean either an explicit dynamics model or some implicit, reliable internal state that a generative model relies on (e.g. when an LLM generates chess moves, whether there is some internal representation of the board). We give our loose working definition below.

Suppose you take actions \(a_t \in \mathcal{A}\) and observe states \(s_t \in \mathcal{S}\) (images, latent vectors, proprioception). A world model is a learned model that, given the current state and a sequence of future actions, predicts what will happen next. Formally, it defines a predictive distribution conditioned on a window of observed states \(s_{t-h:t}\) and the current action \(a_t\):

    \[P_\theta(s_{t+1} \mid s_{t-h:t},\; a_t)\]

that approximates the environment's true conditional \(P(s_{t+1} \mid s_{t-h:t},\; a_t)\). For this blog post, we will assume a Markovian model \(P_\theta(s_{t+1} \mid s_t,\; a_t)\) for simplicity (all results here can be extended to the more general case), and when the model is deterministic it reduces to a map over states:

    \[s_{t+1} = F_\theta(s_t, a_t).\]

In practice the state \(s_t\) is often a learned latent representation (e.g., encoded from pixels), so the model operates in a (theoretically) compact, differentiable space. The key point is that a world model gives you a differentiable simulator; you can roll it forward under hypothetical action sequences and backpropagate through the predictions.
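
To make the setup concrete, here is a minimal sketch of a deterministic world model as a PyTorch module. The architecture and dimensions are illustrative placeholders, not the models from our experiments; later snippets in this post reuse this toy `WorldModel` as a stand-in for \(F_\theta\).

```python
import torch
import torch.nn as nn

# Minimal sketch of a deterministic world model F_theta(s, a) -> s_{t+1}.
# State/action dimensions and the MLP architecture are illustrative only.
class WorldModel(nn.Module):
    def __init__(self, state_dim=32, action_dim=2, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, state_dim),
        )

    def forward(self, s, a):
        # One-step prediction from the current (latent) state and action.
        return self.net(torch.cat([s, a], dim=-1))
```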


Planning: choosing actions by optimizing through the model

Given a start \(s_0\) and a goal \(g\), the simplest planner chooses an action sequence \(\mathbf{a}=(a_0,\dots,a_{T-1})\) by rolling out the model and minimizing terminal error:

    \[\min_{\mathbf{a}} \; \| s_T(\mathbf{a}) - g \|_2^2, \quad \text{where } s_T(\mathbf{a}) = \mathcal{F}_{\theta}^{T}(s_0,\mathbf{a}).\]

Here we use \(\mathcal{F}_\theta^T\) as shorthand for the full rollout through the world model (dependence on model parameters \(\theta\) is implicit):

    \[\mathcal{F}_{\theta}^{T}(s_0, \mathbf{a}) = F_\theta(F_\theta(\cdots F_\theta(s_0, a_0), \cdots, a_{T-2}), a_{T-1}).\]
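
As a concrete reference, here is a hedged sketch of this naive serial planner, using the toy `WorldModel` above; the horizon, step count, and learning rate are arbitrary placeholders:

```python
# Naive planner: gradient descent on the terminal error, backpropagated
# through the entire T-step rollout (backprop through time).
def plan_bptt(F, s0, goal, T=50, action_dim=2, steps=200, lr=1e-2):
    actions = torch.zeros(T, action_dim, requires_grad=True)
    opt = torch.optim.Adam([actions], lr=lr)
    for _ in range(steps):
        s = s0
        for t in range(T):                  # serial rollout: a depth-T graph
            s = F(s, actions[t])
        loss = ((s - goal) ** 2).sum()      # terminal error only
        opt.zero_grad()
        loss.backward()                     # differentiates all T compositions
        opt.step()
    return actions.detach()
```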

At short horizons and in low-dimensional systems, this can work quite well. But as horizons grow and models become larger and more expressive, its weaknesses become amplified.

So why doesn’t this simply work at scale?


Why long-horizon planning is hard (even when everything is differentiable)

There are two separate pain points for general world models, plus a third that is specific to learned, deep learning-based models.

1) Long-horizon rollouts create deep, ill-conditioned computation graphs

Those familiar with backprop through time (BPTT) may notice that we are differentiating through a model applied to itself repeatedly, which leads to the exploding/vanishing gradients problem. Specifically, if we take derivatives (note that we are differentiating vector-valued functions, resulting in Jacobians, which we denote by \(D_x(\cdots)\)) with respect to earlier actions (e.g. \(a_0\)):

    \[D_{a_0} \mathcal{F}_{\theta}^{T}(s_0, \mathbf{a}) = \Bigl(\prod_{t=1}^{T-1} D_s F_\theta(s_t, a_t)\Bigr)\, D_{a_0}F_\theta(s_0, a_0).\]

We see that the Jacobian's conditioning scales exponentially with time \(T\):

    \[\sigma_{\text{max/min}}\bigl(D_{a_0}\mathcal{F}_{\theta}^{T}\bigr) \sim \sigma_{\text{max/min}}(D_s F_\theta)^{T-1},\]

leading to exploding or vanishing gradients.
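
A toy demonstration of this effect, under the assumption of a fixed linear “dynamics” Jacobian with largest singular value 1.5 (purely illustrative):

```python
import torch

# Compose a fixed linear map T times; the gradient w.r.t. the first input
# scales like sigma_max^T, illustrating exploding gradients (shrink the
# singular value below 1 to see vanishing gradients instead).
A = 1.5 * torch.eye(8)
for T in [5, 20, 50]:
    x = torch.randn(8, requires_grad=True)
    s = x
    for _ in range(T):
        s = A @ s
    s.sum().backward()
    print(T, x.grad.norm().item())  # grows roughly like 1.5 ** T
```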

2) The landscape is non-greedy and full of traps

At short horizons, the greedy solution, where we move straight toward the goal at every step, is often good enough. If you only need to plan a few steps ahead, the optimal trajectory usually doesn't deviate much from “head toward \(g\)” at each step.

As horizons grow, two things happen. First, longer tasks are more likely to require non-greedy behavior: going around a wall, repositioning before pushing, backing up to take a better path. And as horizons grow, more of these non-greedy steps tend to be needed. Second, the optimization space itself scales with horizon: \(\mathrm{dim}(\mathcal{A} \times \cdots \times \mathcal{A}) = T\,\mathrm{dim}(\mathcal{A})\), further expanding the space of local minima for the optimization problem.

Loss landscape
Distance to the goal along the optimal path is non-monotonic, and the resulting loss landscape can be rough.


A long-horizon fix: lifting the dynamics constraint

Suppose we treat the dynamics constraint \(s_{t+1} = F_{\theta}(s_t, a_t)\) as a soft constraint, and instead optimize the following penalty function over both actions \((a_0,\ldots,a_{T-1})\) and states \((s_0,\ldots,s_T)\):

    \[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \bigl\|F_\theta(s_t,a_t) - s_{t+1}\bigr\|_2^2, \quad \text{with } s_0 \text{ fixed and } s_T=g.\]

This is also sometimes called collocation in the planning/robotics literature. Note that the lifted formulation shares the same global minimizers as the original rollout objective (both are zero exactly when the trajectory is dynamically feasible). But the optimization landscapes are very different, and we get two immediate benefits:

  • Each world model evaluation \(F_{\theta}(s_t,a_t)\) depends only on local variables, so all \(T\) terms can be computed in parallel across time, resulting in a massive speed-up at longer horizons, and
  • You no longer backpropagate through a single deep \(T\)-step composition to get a learning signal, since the earlier product of Jacobians now splits into a sum of local terms, e.g.:

    \[D_{a_0} \mathcal{L} = 2\bigl(F_\theta(s_0, a_0) - s_1\bigr)^{\top} D_{a_0}F_\theta(s_0, a_0).\]
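
Here is a sketch of the lifted objective with the toy `WorldModel` from above. All \(T\) residuals come from a single batched model call, and no gradient chains across timesteps:

```python
# Lifted (collocation) objective: intermediate states are free variables,
# and all T dynamics residuals are evaluated in one parallel batch over time.
def collocation_loss(F, states, actions, s0, goal):
    # states:  (T-1, state_dim) optimization variables for s_1 ... s_{T-1}
    # actions: (T, action_dim)  optimization variables for a_0 ... a_{T-1}
    s_in  = torch.cat([s0[None], states], dim=0)    # inputs  s_0 ... s_{T-1}
    s_out = torch.cat([states, goal[None]], dim=0)  # targets s_1 ... s_T = g
    pred  = F(s_in, actions)                        # one batched call, no chain
    return ((pred - s_out) ** 2).sum()
```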

Being able to optimize states directly also helps with exploration, since we can temporarily move through unphysical regions of state space to find the optimal plan:

Collocation planning in BallNav
Collocation-based planning lets us perturb states directly and discover midpoints more effectively.

However, lunch isn't free. And indeed, especially for deep learning-based world models, there is a critical issue that makes the above optimization quite difficult in practice.

A challenge for deep learning-based world models: sensitivity of state-input gradients

The tl;dr of this section is: directly optimizing states through a deep learning-based \(F_{\theta}\) is highly brittle, à la adversarial robustness. Even if you train your world model in a lower-dimensional state space, the training process for the world model makes unseen state landscapes very sharp, whether at an unseen state itself or simply along a normal/orthogonal direction to the data manifold.

Adversarial robustness and the “dimpled manifold” model

Adversarial robustness originally looked at classification models \(f_\theta : \mathbb{R}^{w\times h \times c} \to \mathbb{R}^K\), and showed that by following the gradient of a particular logit \(\nabla f_\theta^k\) from a base image \(x\) (not of class \(k\)), you did not need to move far along \(x' = x + \epsilon\nabla f_\theta^k\) to make \(f_\theta\) classify \(x'\) as \(k\) (Szegedy et al., 2014; Goodfellow et al., 2015):

Adversarial example
Depiction of the classic example from (Goodfellow et al., 2015).
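
For concreteness, a minimal FGSM-style sketch of this attack; `classifier` and `target_class` are hypothetical stand-ins, not anything from this project:

```python
# Targeted FGSM-style perturbation: take one small step along the sign of the
# gradient of the target logit; a tiny input change can flip the prediction.
def fgsm_targeted(classifier, x, target_class, eps=0.01):
    x = x.clone().requires_grad_(True)
    logit = classifier(x)[..., target_class]
    logit.sum().backward()
    return (x + eps * x.grad.sign()).detach()
```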

Later work has painted a geometric picture of what is going on: for data near a low-dimensional manifold \(\mathcal{M}\), the training process controls behavior in tangential directions, but does not regularize behavior in orthogonal directions, thus leading to sensitive behavior (Stutz et al., 2019). Stated another way: \(f_\theta\) has a reasonable Lipschitz constant when considering only directions tangential to the data manifold \(\mathcal{M}\), but can have very high Lipschitz constants in normal directions. In fact, it often benefits the model to be sharper in these normal directions, so it can fit more complicated functions more precisely.

Adversarial perturbations leave the data manifold

As a result, such adversarial examples are extremely common even for a single given model. Further, this is not just a computer vision phenomenon; adversarial examples also appear in LLMs (Wallace et al., 2019) and in RL (Gleave et al., 2019).

While there are methods to train more adversarially robust models, there is a known trade-off between model performance and adversarial robustness (Tsipras et al., 2019): specifically, in the presence of many weakly-correlated variables, the model must be sharper to achieve higher performance. Indeed, most modern training algorithms, whether in computer vision or LLMs, do not train adversarial robustness out. Thus, at least until deep learning sees a major regime change, this is a problem we are stuck with.

Why is adversarial robustness a problem for world model planning?

Consider a single component of the dynamics loss we optimize in the lifted-state approach:

    \[\min_{s_t, a_t, s_{t+1}} \|F_\theta(s_t, a_t) - s_{t+1}\|_2^2.\]

Let's further focus on just the base state:

    \[\min_{s_t} \|F_\theta(s_t, a_t) - s_{t+1}\|_2^2.\]

Since world models are typically trained on state/action trajectories \((s_1, a_1, s_2, a_2, \ldots)\), the state-data manifold for \(F_{\theta}\) has dimensionality bounded by the action space:

    \[\mathrm{dim}(\mathcal{M}_s) \le \mathrm{dim}(\mathcal{A}) + 1 + \mathrm{dim}(\mathcal{R}),\]

where \(\mathcal{R}\) is some optional space of augmentations (e.g. translations/rotations). Thus, we can typically expect \(\mathrm{dim}(\mathcal{M}_s)\) to be much lower than \(\mathrm{dim}(\mathcal{S})\), and therefore: it is very easy to find adversarial examples that hack any state into any other desired state.

As a result, the dynamics optimization

    \[\sum_{t=0}^{T-1} \bigl\|F_\theta(s_t,a_t) - s_{t+1}\bigr\|_2^2\]

feels extremely “sticky,” as the base points \(s_t\) can easily trick \(F_{\theta}\) into thinking it has already met its local goal.1

Adversarial world model example
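
The same attack, rephrased for a world model: a hedged sketch (again with the toy `WorldModel`) of how plain gradient descent on the base state can “satisfy” the dynamics residual without finding anything physically meaningful:

```python
# Optimize only the base state s_t so that F "predicts" an arbitrary target.
# This typically converges with a tiny ||s - s_t||: an adversarial state,
# not a useful plan.
def hack_state(F, s_t, a_t, s_target, steps=100, lr=1e-1):
    s = s_t.clone().requires_grad_(True)
    opt = torch.optim.Adam([s], lr=lr)
    for _ in range(steps):
        loss = ((F(s, a_t) - s_target) ** 2).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return s.detach()
```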


1. This adversarial robustness issue, while particularly bad for lifted-state approaches, is not unique to them. Even for serial optimization methods that optimize through the full rollout map \(\mathcal{F}^T\), it is possible to wander into unseen states, where it is very easy for a normal component to be fed into the sensitive normal directions of \(D_s F_{\theta}\). The action Jacobian's chain-rule expansion is

    \[\Bigl(\prod_{t=1}^{T-1} D_s F_\theta(s_t, a_t)\Bigr)\, D_{a_0}F_\theta(s_0, a_0).\]

Consider what happens if any stage of the product has any component normal to the data manifold.


Our fix

This is where our new planner GRASP comes in. The main observation: while \(D_s F_{\theta}\) is untrustworthy and adversarial, the action space is usually low-dimensional and exhaustively trained, so \(D_a F_{\theta}\) is actually reasonable to optimize through and does not suffer from the adversarial robustness issue!

Network diagram showing high-dim state vs low-dim action
The action input is usually lower-dimensional and densely trained (the model has seen every action direction), so action gradients are much better behaved.

At its core, GRASP builds a first-order lifted-state / collocation-based planner that depends only on action Jacobians through the world model. We thus exploit the differentiability of learned world models \(F_{\theta}\) while not falling victim to the inherent sensitivity of the state Jacobians \(D_s F_{\theta}\).

GRASP: Gradient RelAxed Stochastic Planner

As noted before, we start with the collocation planning objective, where we lift the states and relax the dynamics into a penalty:

    \[\min_{\mathbf{s},\mathbf{a}} \mathcal{L}(\mathbf{s}, \mathbf{a}) = \sum_{t=0}^{T-1} \bigl\|F_\theta(s_t,a_t) - s_{t+1}\bigr\|_2^2, \quad \text{with } s_0 \text{ fixed and } s_T=g.\]

We then make two key additions.

Ingredient 1: Exploration by noising the state iterates

Even with a smoother objective, planning is nonconvex. We introduce exploration by injecting Gaussian noise into the virtual state updates during optimization.

A simple version:

    \[s_t \leftarrow s_t - \eta_s \nabla_{s_t}\mathcal{L} + \sigma_{\text{state}}\, \xi, \qquad \xi\sim\mathcal{N}(0,I).\]

Actions are still updated by noiseless gradient descent:

    \[a_t \leftarrow a_t - \eta_a \nabla_{a_t}\mathcal{L}.\]

The state noise helps you “hop” between basins in the lifted space, while the actions remain guided by gradients. We found that noising specifically the states (as opposed to the actions) strikes a good balance between exploration and the ability to find sharper minima.2 A sketch of the two update rules follows.
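
Here, `loss` is assumed to come from the lifted objective above, and the step sizes and noise scale are illustrative:

```python
# One GRASP-style iteration: noised gradient step on the virtual states,
# plain (noiseless) gradient step on the actions.
def grasp_step(loss, states, actions, eta_s=1e-2, eta_a=1e-2, sigma_state=1e-2):
    # states and actions are assumed to be leaf tensors with requires_grad=True.
    g_s, g_a = torch.autograd.grad(loss, [states, actions])
    with torch.no_grad():
        states -= eta_s * g_s
        states += sigma_state * torch.randn_like(states)  # exploration noise
        actions -= eta_a * g_a                            # no noise on actions
```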


2. Because we only noise the states (and not the actions), the corresponding dynamics are not actually Langevin dynamics.


Ingredient 2: Reshape gradients: stop brittle state-input gradients, keep action gradients

As discussed, the sensitive pathway is the gradient that flows into the state input of the world model, \(D_s F_{\theta}\). The most straightforward initial approach is to simply stop state gradients into \(F_{\theta}\) directly:

  • Let \(\bar{s}_t\) be the same value as \(s_t\), but with gradients stopped.

Define the stop-gradient dynamics loss:

    \[\mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) = \sum_{t=0}^{T-1} \bigl\|F_\theta(\bar{s}_t, a_t) - s_{t+1}\bigr\|_2^2.\]

This alone doesn't work. Notice that states now only follow the previous state's step, with nothing forcing the base states to chase the later ones. As a result, there are trivial minima where the trajectory simply stalls at the origin, with only the final action trying to reach the goal in a single step.

Dense goal shaping

We can view the above issue as the goal's signal being cut off entirely from earlier states. One way to fix this is to simply add a dense goal term throughout the prediction:

    \[\mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}) = \sum_{t=0}^{T-1} \bigl\|F_\theta(\bar{s}_t, a_t) - g\bigr\|_2^2.\]

In normal settings this would over-bias toward the greedy solution of chasing the goal directly, but in our setting this is balanced by the stop-gradient dynamics loss's bias toward feasible dynamics. The final objective is then as follows:

    \[\mathcal{L}(\mathbf{s},\mathbf{a}) = \mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) + \gamma \, \mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}).\]

The result is a planning objective with no dependence on state-input gradients.
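
Putting the two losses together, here is a hedged sketch of the reshaped objective with the toy `WorldModel`; `gamma` is an illustrative weight:

```python
# Reshaped GRASP objective: states entering F are detached, so nothing flows
# through the brittle state-input Jacobian D_s F. States still receive
# gradients as the residual targets s_{t+1}; actions receive them via D_a F.
def grasp_loss(F, states, actions, s0, goal, gamma=0.1):
    s_in  = torch.cat([s0[None], states], dim=0).detach()  # bar{s}_t: stop-grad
    s_out = torch.cat([states, goal[None]], dim=0)
    pred  = F(s_in, actions)
    dyn   = ((pred - s_out) ** 2).sum()  # stop-gradient dynamics loss
    dense = ((pred - goal) ** 2).sum()   # dense goal shaping
    return dyn + gamma * dense
```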


Periodic “sync”: temporarily return to true rollout gradients

The lifted stop-gradient objective is great for fast, guided exploration, but it is still an approximation of the original serial rollout objective.

So every \(K_{\text{sync}}\) iterations, GRASP runs a short refinement phase:

  1. Roll out from \(s_0\) using the current actions \(\mathbf{a}\), and take a few small gradient steps on the original serial loss:

    \[\mathbf{a} \leftarrow \mathbf{a} - \eta_{\text{sync}}\,\nabla_{\mathbf{a}}\,\|s_T(\mathbf{a})-g\|_2^2.\]

The lifted-state optimization still provides the core of the optimization, while this refinement step helps keep states and actions grounded toward real trajectories. The refinement step can of course be replaced with a serial planner of your choice (e.g. CEM); the core idea is to retain some of the benefit of the full-path synchronization of serial planners while still mostly enjoying the benefits of lifted-state planning.
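
A minimal sketch of this sync phase, again with the toy `WorldModel` and illustrative hyperparameters:

```python
# Periodic sync: a few true-rollout (serial) gradient steps on the actions,
# to re-ground the lifted solution in real trajectories.
def sync_refine(F, s0, actions, goal, n_steps=5, eta_sync=1e-3):
    # actions is assumed to be a leaf tensor with requires_grad=True.
    for _ in range(n_steps):
        s = s0
        for t in range(actions.shape[0]):
            s = F(s, actions[t])        # real rollout from s_0
        loss = ((s - goal) ** 2).sum()
        (g,) = torch.autograd.grad(loss, [actions])
        with torch.no_grad():
            actions -= eta_sync * g
```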


How GRASP addresses long-range planning

Collocation-based planners offer a natural fix for long-horizon planning, but this optimization is quite difficult through modern world models due to adversarial robustness issues. GRASP provides a simple recipe for a smoother collocation-based planner, alongside continual stochasticity for exploration. As a result, longer-horizon planning ends up not only succeeding more often, but also finding those successes faster:

Push-T planning demo
Push-T demo: longer-horizon planning with GRASP.

Horizon | CEM            | GD             | LatCo           | GRASP
H=40    | 61.4% / 35.3s  | 51.0% / 18.0s  | 15.0% / 598.0s  | 59.0% / 8.5s
H=50    | 30.2% / 96.2s  | 37.6% / 76.3s  | 4.2% / 1114.7s  | 43.4% / 15.2s
H=60    | 7.2% / 83.1s   | 16.4% / 146.5s | 2.0% / 231.5s   | 26.2% / 49.1s
H=70    | 7.8% / 156.1s  | 12.0% / 103.1s | 0.0% / —        | 16.0% / 79.9s
H=80    | 2.8% / 132.2s  | 6.4% / 161.3s  | 0.0% / —        | 10.4% / 58.9s

Push-T results. Success rate (%) / median time to success. Bold = best in row. Note that the median success time biases higher with a higher success rate; GRASP manages to be faster despite its higher success rate.


What's next?

There is still plenty of work to be done on modern world-model planners. We want to exploit the gradient structure of learned world models, and collocation (lifted-state optimization) is a natural approach for long-horizon planning, but it is crucial to understand the typical gradient structure here: smooth, informative action gradients and brittle state gradients. We view GRASP as an initial iteration of such planners.

Extensions to diffusion-based world models (deeper latent timesteps can be viewed as smoothed versions of the world model itself), more sophisticated optimizers and noising strategies, and integrating GRASP into either a closed-loop system or RL policy learning for adaptive long-horizon planning are all natural and interesting next steps.

I genuinely think it's an exciting time to be working on world-model planners. It's a funny sweet spot where the background literature (planning and control overall) is extremely mature and well-developed, but the current setting (pure planning optimization over modern, large-scale world models) is still heavily underexplored. Once we figure out the right ideas, world-model planners will likely become as standard as RL.


For more details, read the full paper or visit the project website.


Citation

@article{psenka2026grasp,
  title={Parallel Stochastic Gradient-Based Planning for World Models},
  author={Michael Psenka and Michael Rabbat and Aditi Krishnapriyan and Yann LeCun and Amir Bar},
  year={2026},
  eprint={2602.00475},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2602.00475}
}

This article was originally published on the BAIR blog and appears here with the authors' permission.


