普林斯顿Zlab发布LLM-Pruning Collection:基于JAX的大语言模型剪枝算法库

3 天前·来源:MarkTechPost
模型剪枝大语言模型JAX开源库AI优化

普林斯顿Zlab研究人员发布了LLM-Pruning Collection,这是一个基于JAX的开源库,整合了多种大语言模型剪枝算法。该库提供了统一的训练和评估框架,支持GPU和TPU,便于比较不同剪枝方法。它包含Minitron、ShortGPT、Wanda、SparseGPT、Magnitude、Sheared Llama和LLM-Pruner等算法实现。

普林斯顿Zlab研究人员发布了LLM-Pruning Collection,这是一个基于JAX的开源库,旨在将大语言模型的主要剪枝算法整合到一个可复现的框架中。该库的目标是简化在GPU和TPU上,使用一致的训练和评估栈来比较块级、层级和权重级剪枝方法。

LLM-Pruning Collection包含三个主要目录:pruning目录提供了多种剪枝方法的实现,包括Minitron、ShortGPT、Wanda、SparseGPT、Magnitude、Sheared Llama和LLM-Pruner。training目录集成了FMS-FSDP用于GPU训练和MaxText用于TPU训练。eval目录提供了基于lm-eval-harness的JAX兼容评估脚本,并支持MaxText加速,速度提升约2到4倍。

该库覆盖了不同粒度的剪枝算法家族。Minitron是NVIDIA开发的一种实用剪枝和蒸馏方法,可将Llama 3.1 8B和Mistral NeMo 12B压缩至4B和8B,同时保持性能。ShortGPT基于Transformer层冗余的观察,通过定义块影响力指标来移除低影响力层。Wanda、SparseGPT和Magnitude是后训练剪枝方法,其中Wanda通过权重幅度和输入激活的乘积来评分权重,SparseGPT使用二阶重建步骤,Magnitude是经典的基线方法。Sheared Llama是一种结构化剪枝方法,学习层、注意力头和隐藏维度的掩码,然后重新训练剪枝后的架构。LLM-Pruner是一个用于大语言模型结构化剪枝的框架,使用基于梯度的重要性分数移除非关键结构,并通过短LoRA调优阶段恢复性能。

背景阅读

大语言模型(LLM)如GPT系列、Llama等,通常包含数十亿甚至万亿参数,导致高计算和内存需求。模型剪枝是一种优化技术,通过移除冗余或不重要的参数来压缩模型,减少推理时间和资源消耗,同时尽可能保持性能。剪枝方法可分为结构化剪枝(如移除整个层或注意力头)和非结构化剪枝(如移除单个权重)。近年来,随着LLM的普及,剪枝算法快速发展,包括Minitron、ShortGPT、Wanda等,但缺乏统一框架进行比较和复现。JAX是一个由Google开发的机器学习库,以其高性能和可扩展性著称,特别适合在GPU和TPU上运行大规模模型。LLM-Pruning Collection的发布旨在填补这一空白,提供一个标准化的平台,促进剪枝技术的研究和应用。

评论 (0)

登录后参与评论

加载评论中...