普林斯顿Zlab研究人员发布了LLM-Pruning Collection,这是一个基于JAX的开源库,旨在将大语言模型的主要剪枝算法整合到一个可复现的框架中。该库的目标是简化在GPU和TPU上,使用一致的训练和评估栈来比较块级、层级和权重级剪枝方法。
LLM-Pruning Collection包含三个主要目录:pruning目录提供了多种剪枝方法的实现,包括Minitron、ShortGPT、Wanda、SparseGPT、Magnitude、Sheared Llama和LLM-Pruner。training目录集成了FMS-FSDP用于GPU训练和MaxText用于TPU训练。eval目录提供了基于lm-eval-harness的JAX兼容评估脚本,并支持MaxText加速,速度提升约2到4倍。
该库覆盖了不同粒度的剪枝算法家族。Minitron是NVIDIA开发的一种实用剪枝和蒸馏方法,可将Llama 3.1 8B和Mistral NeMo 12B压缩至4B和8B,同时保持性能。ShortGPT基于Transformer层冗余的观察,通过定义块影响力指标来移除低影响力层。Wanda、SparseGPT和Magnitude是后训练剪枝方法,其中Wanda通过权重幅度和输入激活的乘积来评分权重,SparseGPT使用二阶重建步骤,Magnitude是经典的基线方法。Sheared Llama是一种结构化剪枝方法,学习层、注意力头和隐藏维度的掩码,然后重新训练剪枝后的架构。LLM-Pruner是一个用于大语言模型结构化剪枝的框架,使用基于梯度的重要性分数移除非关键结构,并通过短LoRA调优阶段恢复性能。