
Coconut: A Framework for Latent Reasoning in LLMs


Paper link: https://arxiv.org/abs/2412.06769

Released: 9th of December 2024

Figure 1. The two reasoning modes of Coconut. In Language Mode (left), the model uses output text tokens as inputs for the next reasoning step. In Latent Mode (right), the model instead feeds its previous hidden state (the output of the last hidden layer) back into itself as input. Figure taken from [1]

Recent research has placed a high focus on LLMs with reasoning capabilities, and for good reason. Reasoning enhances an LLM's power to tackle complex issues, fosters stronger generalization, and introduces an interpretable layer that sheds light on the model's internal thought process.

A major milestone in LLM reasoning is the introduction of Chain-of-Thought Reasoning (CoT)[2], which proved that guiding models to reason step-by-step leads to significant improvements on arithmetic and symbolic reasoning tasks.

Despite their power, reasoning models still operate primarily within the confines of natural language, which can limit their effectiveness. Much of the token space is devoted to maintaining linguistic coherence rather than facilitating abstract reasoning. Addressing this limitation, an intriguing paper from Meta, Training Large Language Models to Reason in a Continuous Latent Space[1], proposes lifting the chain of thought out of natural language entirely, translating back to language only when necessary.

Their contribution can be summarized in three key points:

  1. Chain of Continuous Thought (Coconut): An enhanced reasoning paradigm that builds on CoT. Instead of relying on the final text output, Coconut uses the latent representations from the model's last hidden layer as continuous thoughts.
  2. An exploration of Coconut’s capabilities: indicating how multiple next steps in reasoning can be encoded simultaneously in the latent space.
  3. A deeper analysis of the latent reasoning process itself, so that we can understand Coconut’s internal representation of information.

Coconut, Simplified 

Before delving into the implementation details of Continuous Chain of Thought, it’s important to first establish some foundational grounds.

Given an input sequence x = [x(1), x(2), x(3), …, x(T)], a Chain-of-Thought LLM M, which predicts the next token x(t+1) based on the sequence of previous tokens x(≤t), can be formally described as:

$$M_{CoT}(x_{t+1} \mid x_{\le t}) = \operatorname{softmax}(W h_{t})$$

Where W is the weight matrix of the LLM's language-model head, x(t) is the input token at step t, and h(t) is the last hidden state at position t.

Coconut extends this formulation by removing the dependency on textual input tokens and instead using the model’s last hidden state h(t) as input. This adaptation modifies the LLM’s predictive function into:

$$M_{Coconut}(x_{t+1} \mid x_{\le t}) = \operatorname{softmax}(W h_{t})$$

$$H_{t} = Transformer(E_{t})$$

Where E(t) = [e(x1), e(x2), …, e(xt)] represents the sequence of token embeddings, with e(⋅) denoting the embedding function, and H(t) captures the sequence of hidden states for all tokens up to position t, with h(t) being its last element.

This new formulation allows Coconut to operate in two distinct modes: Language Mode and Latent Mode, as illustrated in Figure 1 (left and right, respectively). In Language Mode, the model functions like a standard LLM, processing textual tokens as input, while in Latent mode, it operates on the internal hidden states instead.

Mode switching plays a critical role in Coconut’s training process. It not only enables the model to learn how to generate meaningful latent representations but also facilitates the decoding of these latent thoughts. Mode transitions are controlled using two special placeholder tokens:

<bot> (begin-of-thought) and <eot> (end-of-thought). Inserting <bot> at position i and <eot> at position j signals the model to operate in Latent Mode for the positions between them (so that x(i) = <bot> and x(j) = <eot>), turning the input embedding sequence into:

$$E_{t}=[e(x_{1}),e(x_{2}),\dots,e(x_{i}),h_{i},h_{i+1},\dots,h_{j-1},e(x_{j}),e(x_{j+1}),\dots,e(x_{t})]$$
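To make the latent feedback loop concrete, here is a minimal sketch of a latent-mode forward pass using the Hugging Face transformers API. It is not the authors' implementation; the prompt, the number of latent steps, and the handling of <bot>/<eot> are simplified assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch (not the authors' code): instead of sampling a text token,
# the last hidden state of the final position is appended to the input
# embeddings and fed back into the model as the next "continuous thought".
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

question = "Betty is saving money for a new wallet, which costs $100. ..."
inputs = tok(question, return_tensors="pt")
embeds = model.get_input_embeddings()(inputs["input_ids"])  # E_t = [e(x_1), ..., e(x_t)]

num_latent_steps = 2  # plays the role of the latent steps between <bot> and <eot>
with torch.no_grad():  # drop this during training so gradients flow through h_t
    for _ in range(num_latent_steps):
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        h_t = out.hidden_states[-1][:, -1:, :]    # last hidden state of the last position
        embeds = torch.cat([embeds, h_t], dim=1)  # append h_t as the next input "embedding"

# After appending e(<eot>), the model switches back to Language Mode and
# decodes the answer from softmax(W h_t) as usual.
```

Because h(t) is appended directly to the embedding sequence rather than being discretized into a token, gradients can flow backwards through the entire chain of latent thoughts, which is what enables the end-to-end training discussed below.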

Figure 2. Training process of Coconut, where at each training stage one language reasoning step is removed and replaced with c latent reasoning steps. Here, c is equal to 1. Figure taken from [1].

Inspired by [3], Coconut employs a multi-stage training curriculum. At each stage k, the first k language-based reasoning steps are replaced with L latent steps, where L = k⋅c, and c is a hyperparameter determining how many latent steps substitute a single language reasoning step. This progression is visualized in Figure 2, where at stage k=0 the model trains purely on standard CoT examples.

The authors' decision to apply multi-stage training decomposes the training process into easier objectives, leading to better results. This pattern was already suggested and validated in [3], which showed that gradually removing intermediate reasoning tokens enables a deeper internalization of reasoning.

Using latent thoughts enables end-to-end gradient-based training: token-level transitions between reasoning steps are replaced with continuous hidden representations, so the network remains fully differentiable across the whole chain. Beyond that, it also allows the model to encode multiple possible next steps concurrently, refining the reasoning path as it advances. A deeper exploration of this mechanism is provided in the Understanding Latent Reasoning section.

To illustrate, let’s examine a simple example drawn from GSM8K[4], one of the datasets used to train Coconut.

Question:

“Betty is saving money for a new wallet, which costs $100. Betty has only half of the money she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice as much as her parents. How much more money does Betty need to buy the wallet?”

Reasoning steps:

1. Betty has only 100 / 2 = $50.

2. Betty's grandparents gave her 15 * 2 = $30.

3. This means Betty needs 100 - 50 - 30 - 15 = $5 more.

4. Answer: 5

This question is then incorporated into the training dataset and used across three distinct stages:

Figure 3. An example of the training process of Coconut. Figure by writer based on example taken from GSM8k[4].

As shown in Figure 3, at stage 0, no latent thoughts are present, only language-based reasoning steps followed by the final answer. In subsequent stages 1 and 2, one language reasoning step is progressively replaced by one latent thought (since c=1), until stage 3, where all reasoning steps are latent. This procedure is applied to each training example in the dataset.
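As a rough illustration of how a single training example might be rewritten at each stage, here is a text-level sketch; the function name, placeholder strings, and string-based representation are my own simplification, not the paper's data pipeline.

```python
def build_stage_example(question, steps, answer, stage, c=1,
                        bot="<bot>", eot="<eot>", latent="<latent>"):
    """Build the stage-`stage` version of one training example.

    The first `stage` language reasoning steps are removed and replaced by
    `stage * c` latent positions between <bot> and <eot>. This is a text-level
    sketch; in practice the latent positions are filled with hidden states
    at training time rather than with literal placeholder tokens.
    """
    latent_part = bot + latent * (stage * c) + eot
    remaining_steps = steps[stage:]          # later language steps are kept
    return " ".join([question, latent_part, *remaining_steps, "Answer:", answer])

# The GSM8K example from Figure 3:
question = "Betty is saving money for a new wallet, which costs $100. ..."
steps = [
    "Betty has only 100 / 2 = $50.",
    "Betty's grandparents gave her 15 * 2 = $30.",
    "This means Betty needs 100 - 50 - 30 - 15 = $5 more.",
]
for k in range(len(steps) + 1):              # stages 0 through 3
    print(f"Stage {k}: {build_stage_example(question, steps, '5', stage=k)}")
```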


Key Findings & Analysis

Three datasets were used to evaluate Coconut’s effectiveness. One focused on mathematical reasoning (GSM8K[4]) and two on logical reasoning: ProntoQA[5] and ProsQA. ProsQA (Proof with Search Question-Answering) is a modified version of ProntoQA, featuring randomly generated directed acyclic graphs (DAGs) of reasoning steps, designed to challenge the model with more complex planning tasks. All models were fine-tuned using GPT-2 as the base model, with c=1 for most datasets, except for GSM8K, where two latent thoughts were used (c=2).

Below is a simplified summary of the results reported in the paper:

Table 1. Accuracy results on three datasets. Results taken from [1].

The models used for comparison with the Coconut architecture are:

  • CoT: Model trained with Chain-of-Thought reasoning, utilizing full reasoning chains during training.
  • No-CoT: Model trained without any reasoning chains; standard language modeling without intermediate reasoning steps.
  • Coconut: The full implementation proposed in this paper.
  • w/o curriculum: The Coconut model trained without the multi-stage curriculum; i.e., no gradual introduction of latent thoughts.
  • w/o thought: Coconut with multi-stage training retained, but without introducing latent thoughts. Language reasoning steps are simply removed over stages instead.
  • Pause as thought [6]: Model trained without latent thoughts entirely, but special tokens are inserted in place of each removed thought. These tokens allow the model additional computation steps before generating an answer. Prior studies [7] have reported improved performance using this approach.

A close examination of the previous table reveals three key insights into the Coconut training paradigm.

First, latent reasoning demonstrates superior performance over Chain-of-Thought on logical reasoning tasks, outperforming it on benchmarks such as ProntoQA[5] and ProsQA. The substantial accuracy gain on ProsQA (97.0% vs. 77.5%) highlights Coconut's effectiveness on more complex, planning-heavy reasoning. On GSM8K, however, Coconut falls behind CoT (42.9% vs. 34.9%), a gap the authors do not analyze in depth; it may stem from the arithmetic nature of GSM8K, which, unlike ProsQA, depends less on planning over many candidate reasoning steps.

Second, comparing Coconut with its counterpart trained without the multi-stage curriculum, we reach the same conclusion suggested by [3]: breaking the training process into simpler, more manageable objectives significantly improves performance. Furthermore, comparing the "w/o curriculum" and "w/o thought" variants shows that the gradual multi-stage curriculum contributes more to the final results than simply swapping language steps for latent thoughts in a single pass. This is an interesting finding that underlines how crucial gradual training is.

Lastly, even when given multi-stage training and additional computation via the pause-as-thought setup, the model still falls short of the full Coconut implementation. The gap is most apparent on GSM8K, reinforcing the hypothesis that the latent thoughts themselves, not just extra compute, boost training effectiveness.


Understanding Latent Reasoning

One of the advantages of Coconut is that, unlike language-based thoughts, a latent thought can hold several candidate next steps at once. This leads to a reasoning process different from ordinary chaining, which we can interpret as a tree search over candidate reasoning steps: each depth layer of the tree corresponds to a latent step k, and each node carries the calculated probability of a specific option. This is covered in more detail in Example #2.

Two main examples of this phenomenon are presented in the paper. We will cover both of them briefly to illustrate the latent reasoning power of this new thought paradigm.

Example #1:

The first example demonstrates how a latent thought can contain multiple possible outcomes within its reasoning tree. To explore this, the continuous thought generated by the model was decoded with the language-model head, purely for probing purposes, allowing us to inspect the continuous thought and verify whether the latent thoughts are being learned correctly.

Question:

James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many meters does he run a week?

Reasoning Steps:

1. He runs 3*3=9 sprints a week

2. So he runs 9*60=540

Answer: 540

Alternative Solution:

1. He runs 3 * 60 = 180 meters each session.

2. So he runs 3 * 180 = 540

When we decode the first latent thought generated by the model, we find that the top three possible outputs are:

1. "180" with a probability of 0.22

2. " 180" (with a leading space) with a probability of 0.20

3. "90" with a probability of 0.13

This shows that the model is indeed considering the first step in the two viable solutions mentioned above.
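A sketch of what such a probe could look like is shown below; it is an illustrative helper rather than the authors' code. The continuous thought h_t is passed through the language-model head, and the top entries of the resulting distribution are inspected.

```python
import torch

def probe_latent_thought(model, tokenizer, h_t, top_k=3):
    """Decode a continuous thought for inspection only.

    `h_t` is the hidden state that would normally be fed back as the next
    input embedding; applying the LM head to it (softmax(W h_t)) shows which
    tokens the thought is "closest" to. Not part of normal inference.
    """
    logits = model.get_output_embeddings()(h_t)        # shape: (1, 1, vocab_size)
    probs = torch.softmax(logits, dim=-1).squeeze()
    top = torch.topk(probs, top_k)
    return [(tokenizer.decode(int(i)), float(p)) for i, p in zip(top.indices, top.values)]

# For the sprint question above, probing the first latent thought would return
# something like [("180", 0.22), (" 180", 0.20), ("90", 0.13)].
```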

Example #2:

The second example gives a clearer illustration of how the tree search is constructed as the number of thoughts increases, pruning older branches that are no longer relevant to the reasoning process and prioritizing more “sound” nodes.

Figure 4. Latent search tree for example #2. On the left are the results of decoding the first latent reasoning step, and on the right are the results of the second latent step. Figure taken from [1].

Question:

“Every grimpus is a yimpus. Every worpus is a jelpus. Every zhorpus is a sterpus. Every impus is a hilpus. Every jompus is a …grimpus is a gwompus. Every rempus is a gorpus. Alex is a sterpus. Every zhorpus is a rompus. Is Alex a gorpus or bompus?”

Reasoning Steps:

1. "Alex is a grimpus."

2. "Every grimpus is a rorpus."

3. "Every rorpus is a bompus."

Answer: “Alex is a bompus.”

The probability of each option can be obtained by multiplying the probabilities of its constituent tokens, as depicted in Figure 4. Here we show the state of the search tree after one latent thought (left) and after two (right).
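For multi-token candidates such as "grimpus", this amounts to chaining the per-token probabilities; below is a small sketch of that computation. The function and variable names are mine, assuming the same Hugging Face-style model as in the earlier sketch.

```python
import torch

def candidate_probability(model, tokenizer, prefix_embeds, candidate):
    """Estimate P(candidate | latent prefix) by multiplying token probabilities.

    `prefix_embeds` is the embedding sequence ending in the latent thought(s).
    Each token of the candidate word is scored with the model's next-token
    distribution and then appended before scoring the following token.
    (Illustrative sketch of the calculation described above.)
    """
    token_ids = tokenizer(candidate, add_special_tokens=False)["input_ids"]
    embeds, prob = prefix_embeds, 1.0
    for tid in token_ids:
        logits = model(inputs_embeds=embeds).logits[:, -1, :]
        prob *= torch.softmax(logits, dim=-1)[0, tid].item()
        next_embed = model.get_input_embeddings()(torch.tensor([[tid]]))
        embeds = torch.cat([embeds, next_embed], dim=1)
    return prob

# Comparing candidate_probability(...) for "grimpus", "sterpus", etc. after one
# or two latent steps reproduces the kind of ranking shown in Figure 4.
```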

We can see from the total calculated probabilities that in step one, the least probable option (0.01) is sterpus, while the second most probable option is grimpus (0.32), which is the correct first step of reasoning in this case. When the search tree is updated with information from the second thought, the node for sterpus is completely disregarded, and the new node with the highest probability is rorpus, which is the correct second reasoning step.

This demonstrates that Coconut can hold multiple candidate next steps in its reasoning process, prioritizing the more promising ones as it goes (like grimpus in step one) and discarding less relevant ones (sterpus in step one). In other words, Coconut can navigate several thoughts in a tree-like manner until it reaches its final conclusion.


Conclusion

In this post, we have discussed Coconut, a new reasoning paradigm that frees LLMs from the need to "think" in language space by reasoning in latent space instead. We covered Coconut's performance compared with other reasoning methods, the importance of multi-stage training, and examples that show how the latent reasoning process works under the hood.

In my opinion, Coconut addresses an interesting research direction, sparking new exploration into latent reasoning approaches and paving the way for more sophisticated machine reasoning models that are not bound by language syntax.


References

[1] S. Hao, S. Sukhbaatar, D. Su, X. Li, Z. Hu, J. Weston and Y. Tian, Training Large Language Models to Reason in a Continuous Latent Space (2024), arXiv preprint arXiv:2412.06769

[2] J. Wei, X. Wang, D. Schuurmans, M. Bosma, B. Ichter, F. Xia, E. Chi, Q. Le and D. Zhou, Chain-of-Thought Prompting Elicits Reasoning in Large Language Models (2022), arXiv preprint arXiv:2201.11903

[3] Y. Deng, Y. Choi and S. Shieber, From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step (2024), arXiv preprint arXiv:2405.14838

[4] K. Cobbe, V. Kosaraju, M. Bavarian, M. Chen, H. Jun, L. Kaiser, M. Plappert, J. Tworek, J. Hilton, R. Nakano, C. Hesse and J. Schulman, Training Verifiers to Solve Math Word Problems (2021), arXiv preprint arXiv:2110.14168

[5] A. Saparov and H. He, Language Models Are Greedy Reasoners: A Systematic Formal Analysis of Chain-of-Thought (2022), arXiv preprint arXiv:2210.01240 

[6] S. Goyal, Z. Ji, A. S. Rawat, A. K. Menon, S. Kumar and V. Nagarajan, Think Before You Speak: Training Language Models With Pause Tokens (2024), arXiv preprint arXiv:2310.02226

[7] J. Pfau, W. Merrill and S. R. Bowman, Let’s Think Dot by Dot: Hidden Computation in Transformer Language Models (2024), arXiv preprint arXiv:2404.15758
