This repository is a simplified and modular implementation of the causal (knowledge) tracing method from “Locating and Editing Factual Associations in GPT” (NeurIPS 2022), providing a one-step solution for model analysis.
📆 More features are still in development (driven by personal interest, so they may be a bit varied and unsystematic 🤣).
Built on PyTorch forward hooks, we define multiple hook functions with different functionalities. By simply registering them by name to the corresponding layers of an LLM, we can automatically log intermediate states and modify or replace these intermediate hidden states.
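As a rough illustration (the module-name filter and the `states` dictionary are hypothetical, not this repository’s actual API), registering a logging hook on layers selected by name might look like this:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B")
states = {}  # hypothetical store for logged hidden states

def make_logging_hook(name):
    # Forward hook: called with (module, inputs, output) after the module runs.
    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        states[name] = hidden.detach().clone()  # log a copy, without gradients
    return hook

handles = []
for name, module in model.named_modules():
    if name.endswith(".mlp"):  # select layers purely by their string name
        handles.append(module.register_forward_hook(make_logging_hook(name)))

# ... run a forward pass here; `states` then holds one tensor per hooked layer ...
for handle in handles:
    handle.remove()  # clean up so later runs are not affected
```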
We use the `known_1000` dataset, which contains over 1,000 factual prompts for causal tracing and can be found in `./data`.
Pipeline:
- Clean run: compute and store the intermediate states during the model’s forward pass without intervention.
- Corrupted run: compute and store the intermediate states during the model’s forward pass after adding noise to the subject token’s embedding.
- Corrupted-with-restoration run: add noise to the subject token’s embedding, and during the model’s forward pass, restore the intermediate states at each position to those recorded during the clean run. Record the final predicted probability of the target token. (A minimal sketch of these interventions follows this list.)
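Both interventions can be expressed as forward hooks. The sketch below is illustrative only: the noise scale and names are assumptions (the original paper scales the noise to the embedding statistics):

```python
import torch

def corrupt_embedding_hook(subject_span, noise_std=0.1):
    # Forward hook on the embedding layer: add Gaussian noise to the
    # subject tokens' embeddings (noise_std is an illustrative value).
    start, end = subject_span
    def hook(module, inputs, output):
        noised = output.clone()
        noised[:, start:end] += noise_std * torch.randn_like(noised[:, start:end])
        return noised  # returning a tensor replaces the module's output
    return hook

def restore_hook(clean_hidden, position):
    # Forward hook on a decoder layer: overwrite one position's hidden
    # state with the value recorded during the clean run.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        patched = out.clone()
        patched[:, position] = clean_hidden[:, position]
        return (patched,) + output[1:] if isinstance(output, tuple) else patched
    return hook
```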
Mediators:
- Hidden states: perform a complete restoration of the hidden states computed at a specific layer.
- Hidden states with Attention severed: restore the hidden states at the input of a specific layer, but keep the Attention output in its corrupted-run state.
- Hidden states with MLP severed: restore the hidden states at the input of a specific layer, but keep the MLP output in its corrupted-run state. (A sketch of the severing hook follows this list.)
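For the severed variants, the sub-module’s output is pinned to its corrupted-run value while the layer input is restored. A hedged sketch for the attention case (the recorded tensor and the tuple layout are assumptions about the module’s output):

```python
def sever_attention_hook(corrupted_attn_output):
    # Forward hook on the attention sub-module: replace its output with the
    # tensor recorded during the corrupted run, cutting attention's
    # contribution out of the restoration.
    def hook(module, inputs, output):
        if isinstance(output, tuple):  # some attention modules also return weights
            return (corrupted_attn_output,) + output[1:]
        return corrupted_attn_output
    return hook
```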
Since different LLMs perform differently on the sample set, filtering is necessary to select an appropriate subset of samples for causal tracing.
We only retain questions that the model answers correctly (i.e., the answer is the token with the highest predicted probability), so that knowledge can be located and traced reliably: only when the model answers correctly can we assume the relevant knowledge has been recalled.
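A sketch of this filter (the dataset field names and the leading-space handling are assumptions about the data format):

```python
import torch

@torch.no_grad()
def answers_correctly(model, tokenizer, prompt, answer):
    # Keep a sample only if the model's top-1 next-token prediction
    # matches the first token of the expected answer.
    inputs = tokenizer(prompt, return_tensors="pt")
    logits = model(**inputs).logits
    predicted = logits[0, -1].argmax().item()
    answer_ids = tokenizer(" " + answer, add_special_tokens=False)["input_ids"]
    return predicted == answer_ids[0]

# e.g.: filtered = [s for s in samples
#                   if answers_correctly(model, tokenizer, s["prompt"], s["attribute"])]
```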
Create heatmaps of the average indirect effect (AIE) obtained for the three types of mediators above.
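A plotting sketch with matplotlib (the matrix shape and labels are assumptions; one such figure would be produced per mediator):

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_aie_heatmap(aie, position_labels, title):
    # aie: [num_positions, num_layers] matrix of average indirect effects.
    fig, ax = plt.subplots()
    im = ax.imshow(aie, cmap="Purples", aspect="auto")
    ax.set_yticks(np.arange(len(position_labels)), labels=position_labels)
    ax.set_xlabel("layer")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, label="AIE")
    fig.savefig(f"{title}.png", bbox_inches="tight")
```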
P.S. The stored intermediate states and the prediction probabilities recorded under the different intervention methods are saved as binary files in the `./results` folder. These files are read during causal tracing and heatmap generation, facilitating reuse.
P.P.S. In our current implementation, the positions within the sentence include: the first subject token, the last subject token, the "further tokens" (all tokens after the subject up to, but not including, the last token), and the last token.
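A hedged sketch of mapping these groups to token indices (the naive span search assumes the subject tokenizes identically inside the prompt, which may need adjusting for leading spaces):

```python
def token_position_groups(tokenizer, prompt, subject):
    # Map the four position groups used in the heatmaps to token indices.
    tokens = tokenizer(prompt)["input_ids"]
    subj = tokenizer(" " + subject, add_special_tokens=False)["input_ids"]
    first_subj = last_subj = None
    for i in range(len(tokens) - len(subj) + 1):  # naive span search
        if tokens[i:i + len(subj)] == subj:
            first_subj, last_subj = i, i + len(subj) - 1
            break
    last = len(tokens) - 1
    return {
        "first subject token": [first_subj],
        "last subject token": [last_subj],
        "further tokens": list(range(last_subj + 1, last)),
        "last token": [last],
    }
```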
The causal tracing heatmaps for LLaMA-3.2-3B and LLaMA-3.1-8B are shown in the figures below, consistent with the analysis results in the original paper.
🦙 LLaMA-3.1-8B:
🦙 LLaMA-3.2-3B:
All dependencies are listed in `requirements.txt`; simply run the following command in the terminal:
pip install -r requirements.txt
We provide a quick-start demo for the LLaMA models. The entire configuration is passed via a `.yaml` file, which can be found and customized in the `./hparams` folder. Since layer naming differs between models, you’ll need to locate the correct layer names when customizing.
Layer IDs are represented by the `[LAYER_NUM]` placeholder in the configuration file; it is automatically recognized and replaced with the corresponding layer IDs during computation.
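A minimal sketch of this expansion (the file path and config keys here are hypothetical, not the repo’s actual schema):

```python
import yaml

with open("hparams/llama.yaml") as f:  # illustrative path
    cfg = yaml.safe_load(f)

template = cfg["layer_name"]  # e.g. "model.layers.[LAYER_NUM].mlp"
layer_names = [template.replace("[LAYER_NUM]", str(i))
               for i in range(cfg["num_layers"])]
```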
Once the model configuration is complete, you can run the following command in the terminal to perform the full knowledge tracing process described above.
P.S. Before running, don’t forget to modify the path to the config file in the `run.sh` script.
bash run.sh