通过Unsloth微调Qwen2.5-VL实现复杂数学公式的OCR
nanshan 2025-05-08 20:15 20 浏览 0 评论
【学习目标】
- 理解Unsloth的核心优化原理与基础实践;
- 掌握基于Unsloth的高效微调工作流。
【知识储备】
1. Unsloth简介
Unsloth是一个专为大型语言模型(LLM)设计的微调框架,旨在提高微调效率并减少显存占用。 它通过手动推导计算密集型数学步骤并手写 GPU 内核,实现了无需硬件更改即可显著加快训练速度。
主要功能点:
- 高效微调:Unsloth通过深度优化,使 LLM 的微调速度提高 2-5 倍,显存使用量减少约 80%,且准确度无明显下降。
- 广泛的模型支持:目前支持的模型包括目前各类主流模型,用户可以根据需求适合的模型进行微
调。
- 兼容性:Unsloth与HuggingFace生态兼容,用户可以轻松将其与 traformers、peft、trl 等库结合,实现模型的全参微调(full)、监督微调(SFT)和广义强化学习优化(GRPO)、基于人类反馈的奖励建模(包括DPO、ORPO、KTO等方法)、持续预训练(continued pretraining)、文本补全(text completion)以及其他前沿训练方法。
- 内存优化: 通过 4 位和 16 位的 QLoRA/LoRA 微调,unsloth 显著了显存占用,使得在资源受限的环境中也能大的微调。
Unsloth核心优势:
- Unsloth简化了整个微调工作流程,包括模型加载、量化、训练、评估、运行、保存、导出,以及与推理引擎(如Ollama、llama.cpp和vLLM)的集成;
- Unsloth相比传统方法,Unsloth 能够在更短的时间内、更少的显存消耗完成微调任务,节省时间及硬件成本;
- Unsloth定期与Huggingface、Google和Meta团队合作,以修复LLM训练和模型中的错误(例如,之前有报告有为Gemma 3和Phi-4所做的错误排查工作)。因此,在使用Unsloth进行模型微调时能看到最准确的结果。
- 开源免费: Unsloth提供开源版本,用户可以在 Google Colab 或 Kaggle Notebooks 上免费试用,方便上手体验。
总的来说,unsloth 为大型语言模型的微调提供了高效、低成本的解决方案,适合希望在有限资源下进行模型微调的开发者和研究人员。
【任务实施】
1. 运行环境要求
1.1、硬件环境
序 | 名称 | 建议配置 |
1 | CPU | Intel I7 |
2 | 显卡 | NVIDIA GeForce RTX 4090 |
3 | 内存 | 16G |
4 | 系统 | Ubuntu20.04 + |
注:根据微调的模型参数及量化方法不同,显存要求也会不一样,参考值如下:
参数量 | QLoRA (4-bit) | LoRA (16-bit) |
3B | 3.5 GB | 8 GB |
7B | 5 GB | 19 GB |
8B | 6 GB | 22 GB |
9B | 6.5 GB | 24 GB |
11B | 7.5 GB | 29 GB |
14B | 8.5 GB | 33 GB |
27B | 22 GB | 64 GB |
32B | 26 GB | 76 GB |
40B | 30 GB | 96 GB |
70B | 41 GB | 164 GB |
81B | 48 GB | 192 GB |
90B | 53 GB | 212 GB |
405B | 237 GB | 950 GB |
1.2、软件环境
序 | 名称 | 版本 |
1 | Python | 3.10+ |
2 | CUDA | 12.1+ |
3 | JupyterLab | 3.5+ |
2. Unsloth安装
2.1、创建并配置虚拟环境
打开一个新的命令行终端,创建Conda新环境,名称可自定义,这里以"unsloth"为例:
$ conda create -n unsloth python=3.11 ipykernel -y
激活新建的环境:
$ conda activate unsloth
激活后,终端提示符通常会显示环境名称(unsloth),表示您已在该环境当中。
将unsloth虚拟环境加入到Jupyterlab的内核中,以便后续.ipynb文档可以选择该环境运行:
$ python -m ipykernel install --user --name=unsloth --display-name "unsloth"
运行后,点击右上角内核切换按钮,进行内核切换,查看是否有出现unsloth内核,如果没有请在菜单栏重启内核再操作:
2.2、Unsloth安装
In [ ]:
import sys
PYTHON_PATH=sys.executable
print(PYTHON_PATH)
In [ ]:
%%capture
!{PYTHON_PATH} -m pip install unsloth modelscope ipywidgets tensorboard
- %%capture:隐藏命令的输出,避免安装过程中的冗长日志刷屏。但注意观察右上角的运行状态,显示"忙碌",请耐心等待。
如果是开发环境,可以继续运行以下命令,从 GitHub 仓库安装Unsloth的最新开发版(可能包含未发布的修复或功能)。
In [ ]:
!{PYTHON_PATH} -m pip install \
--force-reinstall \
--no-cache-dir \
--no-deps \
git+https://github.com/unslothai/unsloth.git
2.3、验证Unsloth
运行以下命令查看Unsloth的安装情况 ,如果安装成功,会显示版本号等信息。
In [ ]:
!{PYTHON_PATH} -m pip show unsloth
3. 通过Unsloth进行Qwen2.5-VL模型推理
3.1、Qwen多模态模型下载
通过ModelScope SDK将Qwen2.5-VL多模态模型下载到指定目录,使用的是7B经过指令微调后的模型。
In [ ]:
import os
from modelscope import snapshot_download
# 定义基座模型以及模型存放目录
MODEL_NAME_OR_PATH = "models/Qwen2.5-VL-7B-Instruct"
BASE_MODEL = "unsloth/Qwen2.5-VL-7B-Instruct"
# 如目录不存在,则下载模型
if not os.path.exists(MODEL_NAME_OR_PATH):
snapshot_download(BASE_MODEL, local_dir=MODEL_NAME_OR_PATH)
# 目录已存在,打印文件列表
else:
print("模型已存在,跳过下载")
files = [item for item in os.listdir(MODEL_NAME_OR_PATH) if not item.startswith('.')]
for file in files:
print(file)
3.2. 导入相关依赖库
In [ ]:
from unsloth import FastVisionModel
import torch
from PIL import Image, ImageOps
from IPython.display import display
from transformers import TextStreamer
3.3、加载模型和分词器
In [ ]:
model, tokenizer = FastVisionModel.from_pretrained(
model_name=MODEL_NAME_OR_PATH,
max_seq_length=2048,
dtype=None,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
)
3.4、微调前的模型推理
将推理过程封装成一个函数,方便后续多次调用,代码如下:
In [ ]:
def inference(text, image_file, system_prompt = None):
"""
推理函数
Args
text: 输入的文本
image_file: 图片文件路径
system_prompt: 系统提示语,默认为None
"""
# 显示图片
image = Image.open(image_file)
image = ImageOps.exif_transpose(image)
display(image)
# 将模型切换到推理模式(会关闭 dropout 等训练专用层,优化推理速度)
FastVisionModel.for_inference(model)
# 构造符合ChatML风格的输入消息
messages = []
if system_prompt:
messages.append({"role": "system", "content": [{"type":"text", "text": system_prompt}]})
messages = [
{"role": "user", "content":
[
{"type": "image"},
{"type": "text", "text": text}
]}
]
# 将messages转换为模型所需的对话格式字符串
input_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True)
# 图像会被编码为视觉特征向量,文本按正常分词流程处理
# 输出包含input_ids(文本)、pixel_values(图像)等键的字典
model_inputs = tokenizer(
text=input_text,
images=image,
padding=True,
add_special_tokens=False,
return_tensors="pt"
)
model_inputs = model_inputs.to(model.device)
# 通过TextStreamer实现流式输出
model.generate(
**model_inputs,
max_new_tokens=512,
use_cache = True,
temperature = 1.5,
min_p = 0.1,
streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True),
)
函数编写完成后,现在我们来先传递一个问题、一张图片给函数,看看Unsloth框架的模型推理效果。
In [ ]:
inference(text="图片表达了什么?" , image_file="assets/candy.jpg")
那么在进行微调前,我们首先验证下原模型Qwen2.5-VL,对含有数学公式的图片识别效果怎么样,同样调用上面的推理函数inference:
In [ ]:
inference(
text="为图片生成LaTeX表达式",
image_file="assets/demo_pic_1.jpg",
system_prompt="你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式"
)
观察模型生成的结果,将该结果通过LaTeX公式生成器验证下,看是否正确,同时记录起来供后续做对比。
对比两者,可以发现模型虽然在提示词的作用下,发挥作用,但是回答的并不正确。但下来我们需要对它进行微调,使其更适应处理复杂数学公式。
4. 通过Unsloth微调Qwen2.5-VL实现复杂数学公式的OCR
4.1、微调数据集的准备
通过huggingface datasets库下载数据集:
In [ ]:
from datasets import load_dataset
# 定义数据集名称及保存路径
dataset_name = "unsloth/LaTeX_OCR"
dataset_dir = "datasets/LaTeX_OCR"
exist = os.path.exists(dataset_dir)
dataset = load_dataset(dataset_name, split="train", cache_dir=dataset_dir )
if not exist:
print(f"数据集已下载保存到{dataset_dir},共 {len(dataset)} 条样本")
else:
print(f"数据集已存在,已从{dataset_dir}加载数据集")
数据集已存在,已从datasets/LaTeX_OCR加载数据集
让我们来简单了解一下这个数据集。我们看一看第三张图片是什么,以及对应的标题是什么。
In [3]:
dataset[2]["image"]
Out[3]:
In [4]:
dataset[2]["text"]
Out[4]:
'H ^ { \\prime } = \\beta N \\int d \\lambda \\biggl \\{ \\frac { 1 } { 2 \\beta ^ { 2 } N ^ { 2 } } \\partial _ { \\lambda } \\zeta ^ { \\dagger } \\partial _ { \\lambda } \\zeta + V ( \\lambda ) \\zeta ^ { \\dagger } \\zeta \\biggr \\} \\ .'
我们运行下一行代码,直接在JupyterLab中渲染上述dataset[2]["text"]的LaTeX表达式,看是否与图片一致:
In [5]:
from IPython.display import display, Math
latex = dataset[2]["text"]
display(Math(latex))
H′=βN∫dλ{12β2N2λζ+λζ+V(λ)ζ+ζ} .H′=βN∫dλ{12β2N2λζ+λζ+V(λ)ζ+ζ} .
可以发现与原图的公式一模型一样。
那么了解完数据集结构之后,我们需要将这些数据格式化成Qwen2.5-VL需要的Json格式(本质上所有视觉微调任务都是类似ChatML格式,ChatML格式仅仅是sharegpt格式的一种特殊情况),如下所示:
[
{ "role": "user",
"content": [{"type": "text", "text": Q}, {"type": "image", "image": image} ]
},
{ "role": "assistant",
"content": [{"type": "text", "text": A} ]
},
]
定义数据预处理函数data_process,目的是处理数据集的每条数据,将其格式化成Qwen2.5-VL需要的Json格式:
In [ ]:
instruction = "为图片生成LaTeX表达式"
def data_process(sample):
conversation = [
{ "role": "user", "content" : [
{"type" : "text", "text" : instruction},
{"type" : "image", "image" : sample["image"]} ]
},
{ "role" : "assistant",
"content" : [
{"type" : "text", "text" : sample["text"]} ]
},
]
return { "messages" : conversation }
调用数据处理预函数data_process,批量将所有数据格式化为微调输入格式,返回给新的变量converted_dataset:
In [ ]:
converted_dataset = [data_process(sample) for sample in dataset]
我们展示下经过格式化后的首条数据内容:
In [ ]:
converted_dataset[0]
4.2、LoRA微调配置
In [ ]:
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers = True,
finetune_language_layers = True,
finetune_attention_modules = True,
finetune_mlp_modules = True,
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
r = 16,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing="unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
4.3、训练参数配置
In [ ]:
from trl import SFTTrainer, SFTConfig
from unsloth.trainer import UnslothVisionDataCollator
from unsloth import is_bf16_supported
from datetime import datetime
output_dir = f"outputs/exp_{datetime.now().strftime('%Y%m%d_%H%M')}"
# 将模型切换到训练模式
FastVisionModel.for_training(model)
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
data_collator = UnslothVisionDataCollator(model, tokenizer),
train_dataset = converted_dataset,
args = SFTConfig(
output_dir = output_dir,
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
max_steps = 10,
# num_train_epochs = 2,
learning_rate = 2e-4,
fp16 = not is_bf16_supported(),
bf16 = is_bf16_supported(),
report_to = "tensorboard",
logging_steps = 5,
logging_dir=output_dir,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
remove_unused_columns = False,
dataset_text_field = "",
dataset_kwargs = {"skip_prepare_dataset": True},
dataset_num_proc = 4,
max_seq_length = 2048,
),
)
4.4、启动训练
打印当前GPU显存信息:
In [ ]:
# 获取索引为0的GPU设备的详细属性
gpu_stats = torch.cuda.get_device_properties(0)
# 返回PyTorch当前预留的显存峰值
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# GPU的物理显存总量
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}.")
print(f"1)最大显存 = {max_memory} GB.")
print(f"2)预留 {start_gpu_memory} GB 的显存.")
调用train()开始训练:
In [ ]:
trainer_stats = trainer.train()
显示最终内存和时间统计:
In [ ]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"训练耗时:{trainer_stats.metrics['train_runtime']}秒.")
print(
f"训练耗时:{round(trainer_stats.metrics['train_runtime']/60, 2)}分钟."
)
print(f"峰值预留显存 = {used_memory} GB.")
print(f"LoRA训练专用显存峰值 = {used_memory_for_lora} GB.")
print(f"峰值预留显存占总显存比例 = {used_percentage} %.")
print(f"LoRA训练显存占总显存比例 = {lora_percentage} %.")
4.5、微调后结果分析
指定训练日志所在目录,调用tensorboard命令启动,会在--port指定的端口启动一个可视化WEB服务,在浏览器中打开 http://localhost:6006(如果是云服务器的话,根据IP或映射访问) 即可查看可视化结果。
In [ ]:
!tensorboard --logdir {output_dir} --port 6006
运行以上命令后,打开浏览器访问,如果损失率没有稳定下降,需要调整训练参数重新开始训练。
4.6、模型微调后的推理
现在开始运行微调后的模型,使用相同的推理函数、相同的图片以及提示词:
In [ ]:
inference(
text="为图片生成LaTeX表达式",
image_file="assets/demo_pic_1.jpg",
system_prompt="你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式"
)
将输出的结果拷贝到LaTeX公式生成器验证下:
继续与推理前的结果对比,可以发现经过微调后的模型,生成的结果更加接近、符合预期。但由于训练步数/轮次太少,因此生成的结果还并不能完全正确,感兴趣的大家可以继续增大训练轮次,但时间会久些。
4.7、保存微调模型
将最终模型保存为LoRA适配器,可以使用Huggingface的save_pretrained方法进行本地保存,同时也要把分词器保存。
In [ ]:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"LoRA权重文件已保存在:{output_dir}")
但上述代码只是保存了LoRA适配器,而不是完整的模型,通过以下代码保存为完整的float16精度模型。该精度的模型可以使用vLLM、transformers等工具进行加载推理。
In [ ]:
new_model_dir = "models/Qwen2.5-VL-7B-LaTeXOCR"
model.save_pretrained_merged(
new_model_dir,
tokenizer,
save_method="merged_16bit",
)
print(f"模型已合并并保存到:{new_model_dir}")
model.save_pretrained_merged方法会逐层检查基础模型,并去huggingface下载相应的基础模型,所以尽量开启HF国内镜像源或代理,不然会抵账,下载也需要点时间。
合并保存完成后,观察models/Qwen2.5-VL-7B-LaTeXOCR目录,生成了以下文件:
到此,我们使用Qwen2.5-VL多模态基座模型,通过Unsloth的QLoRA微调方法,成功训练了第一个模型,让其可以识别LaTeX公式。让你对Unsloth有个初始的认识,更多其它模型的训练方法,请继续往下实战。
相关推荐
- 0722-6.2.0-如何在RedHat7.2使用rpm安装CDH(无CM)
-
文档编写目的在前面的文档中,介绍了在有CM和无CM两种情况下使用rpm方式安装CDH5.10.0,本文档将介绍如何在无CM的情况下使用rpm方式安装CDH6.2.0,与之前安装C5进行对比。环境介绍:...
- ARM64 平台基于 openEuler + iSula 环境部署 Kubernetes
-
为什么要在arm64平台上部署Kubernetes,而且还是鲲鹏920的架构。说来话长。。。此处省略5000字。介绍下系统信息;o架构:鲲鹏920(Kunpeng920)oOS:ope...
- 生产环境starrocks 3.1存算一体集群部署
-
集群规划FE:节点主要负责元数据管理、客户端连接管理、查询计划和查询调度。>3节点。BE:节点负责数据存储和SQL执行。>3节点。CN:无存储功能能的BE。环境准备CPU检查JDK...
- 在CentOS上添加swap虚拟内存并设置优先级
-
现如今很多云服务器都会自己配置好虚拟内存,当然也有很多没有配置虚拟内存的,虚拟内存可以让我们的低配服务器使用更多的内存,可以减少很多硬件成本,比如我们运行很多服务的时候,内存常常会满,当配置了虚拟内存...
- 国产深度(deepin)操作系统优化指南
-
1.升级内核随着deepin版本的更新,会自动升级系统内核,但是我们依旧可以通过命令行手动升级内核,以获取更好的性能和更多的硬件支持。具体操作:-添加PPAs使用以下命令添加PPAs:```...
- postgresql-15.4 多节点主从(读写分离)
-
1、下载软件[root@TX-CN-PostgreSQL01-252software]#wgethttps://ftp.postgresql.org/pub/source/v15.4/postg...
- Docker 容器 Java 服务内存与 GC 优化实施方案
-
一、设置Docker容器内存限制(生产环境建议)1.查看宿主机可用内存bashfree-h#示例输出(假设宿主机剩余16GB可用内存)#Mem:64G...
- 虚拟内存设置、解决linux内存不够问题
-
虚拟内存设置(解决linux内存不够情况)背景介绍 Memory指机器物理内存,读写速度低于CPU一个量级,但是高于磁盘不止一个量级。所以,程序和数据如果在内存的话,会有非常快的读写速度。但是,内存...
- Elasticsearch性能调优(5):服务器配置选择
-
在选择elasticsearch服务器时,要尽可能地选择与当前业务量相匹配的服务器。如果服务器配置太低,则意味着需要更多的节点来满足需求,一个集群的节点太多时会增加集群管理的成本。如果服务器配置太高,...
- Es如何落地
-
一、配置准备节点类型CPU内存硬盘网络机器数操作系统data节点16C64G2000G本地SSD所有es同一可用区3(ecs)Centos7master节点2C8G200G云SSD所有es同一可用区...
- 针对Linux内存管理知识学习总结
-
现在的服务器大部分都是运行在Linux上面的,所以,作为一个程序员有必要简单地了解一下系统是如何运行的。对于内存部分需要知道:地址映射内存管理的方式缺页异常先来看一些基本的知识,在进程看来,内存分为内...
- MySQL进阶之性能优化
-
概述MySQL的性能优化,包括了服务器硬件优化、操作系统的优化、MySQL数据库配置优化、数据库表设计的优化、SQL语句优化等5个方面的优化。在进行优化之前,需要先掌握性能分析的思路和方法,找出问题,...
- Linux Cgroups(Control Groups)原理
-
LinuxCgroups(ControlGroups)是内核提供的资源分配、限制和监控机制,通过层级化进程分组实现资源的精细化控制。以下从核心原理、操作示例和版本演进三方面详细分析:一、核心原理与...
- linux 常用性能优化参数及理解
-
1.优化内核相关参数配置文件/etc/sysctl.conf配置方法直接将参数添加进文件每条一行.sysctl-a可以查看默认配置sysctl-p执行并检测是否有错误例如设置错了参数:[roo...
- 如何在 Linux 中使用 Sysctl 命令?
-
sysctl是一个用于配置和查询Linux内核参数的命令行工具。它通过与/proc/sys虚拟文件系统交互,允许用户在运行时动态修改内核参数。这些参数控制着系统的各种行为,包括网络设置、文件...
你 发表评论:
欢迎- 一周热门
-
-
UOS服务器操作系统防火墙设置(uos20关闭防火墙)
-
极空间如何无损移机,新Z4 Pro又有哪些升级?极空间Z4 Pro深度体验
-
手机如何设置与显示准确时间的详细指南
-
NAS:DS video/DS file/DS photo等群晖移动端APP远程访问的教程
-
如何在安装前及安装后修改黑群晖的Mac地址和Sn系列号
-
如何修复用户配置文件服务在 WINDOWS 上登录失败的问题
-
一加手机与电脑互传文件的便捷方法FileDash
-
日本海上自卫队的军衔制度(日本海上自卫队的军衔制度是什么)
-
10个免费文件中转服务站,分享文件简单方便,你知道几个?
-
爱折腾的特斯拉车主必看!手把手教你TESLAMATE的备份和恢复
-
- 最近发表
- 标签列表
-
- linux 查询端口号 (58)
- docker映射容器目录到宿主机 (66)
- 杀端口 (60)
- yum更换阿里源 (62)
- internet explorer 增强的安全配置已启用 (65)
- linux自动挂载 (56)
- 禁用selinux (55)
- sysv-rc-conf (69)
- ubuntu防火墙状态查看 (64)
- windows server 2022激活密钥 (56)
- 无法与服务器建立安全连接是什么意思 (74)
- 443/80端口被占用怎么解决 (56)
- ping无法访问目标主机怎么解决 (58)
- fdatasync (59)
- 405 not allowed (56)
- 免备案虚拟主机zxhost (55)
- linux根据pid查看进程 (60)
- dhcp工具 (62)
- mysql 1045 (57)
- 宝塔远程工具 (56)
- ssh服务器拒绝了密码 请再试一次 (56)
- ubuntu卸载docker (56)
- linux查看nginx状态 (63)
- tomcat 乱码 (76)
- 2008r2激活序列号 (65)