通过Unsloth微调Qwen2.5-VL实现复杂数学公式的OCR
nanshan 2025-05-08 20:15 8 浏览 0 评论
【学习目标】
- 理解Unsloth的核心优化原理与基础实践;
- 掌握基于Unsloth的高效微调工作流。
【知识储备】
1. Unsloth简介
Unsloth是一个专为大型语言模型(LLM)设计的微调框架,旨在提高微调效率并减少显存占用。 它通过手动推导计算密集型数学步骤并手写 GPU 内核,实现了无需硬件更改即可显著加快训练速度。
主要功能点:
- 高效微调:Unsloth通过深度优化,使 LLM 的微调速度提高 2-5 倍,显存使用量减少约 80%,且准确度无明显下降。
- 广泛的模型支持:目前支持的模型包括目前各类主流模型,用户可以根据需求适合的模型进行微
调。
- 兼容性:Unsloth与HuggingFace生态兼容,用户可以轻松将其与 traformers、peft、trl 等库结合,实现模型的全参微调(full)、监督微调(SFT)和广义强化学习优化(GRPO)、基于人类反馈的奖励建模(包括DPO、ORPO、KTO等方法)、持续预训练(continued pretraining)、文本补全(text completion)以及其他前沿训练方法。
- 内存优化: 通过 4 位和 16 位的 QLoRA/LoRA 微调,unsloth 显著了显存占用,使得在资源受限的环境中也能大的微调。
Unsloth核心优势:
- Unsloth简化了整个微调工作流程,包括模型加载、量化、训练、评估、运行、保存、导出,以及与推理引擎(如Ollama、llama.cpp和vLLM)的集成;
- Unsloth相比传统方法,Unsloth 能够在更短的时间内、更少的显存消耗完成微调任务,节省时间及硬件成本;
- Unsloth定期与Huggingface、Google和Meta团队合作,以修复LLM训练和模型中的错误(例如,之前有报告有为Gemma 3和Phi-4所做的错误排查工作)。因此,在使用Unsloth进行模型微调时能看到最准确的结果。
- 开源免费: Unsloth提供开源版本,用户可以在 Google Colab 或 Kaggle Notebooks 上免费试用,方便上手体验。
总的来说,unsloth 为大型语言模型的微调提供了高效、低成本的解决方案,适合希望在有限资源下进行模型微调的开发者和研究人员。
【任务实施】
1. 运行环境要求
1.1、硬件环境
序 | 名称 | 建议配置 |
1 | CPU | Intel I7 |
2 | 显卡 | NVIDIA GeForce RTX 4090 |
3 | 内存 | 16G |
4 | 系统 | Ubuntu20.04 + |
注:根据微调的模型参数及量化方法不同,显存要求也会不一样,参考值如下:
参数量 | QLoRA (4-bit) | LoRA (16-bit) |
3B | 3.5 GB | 8 GB |
7B | 5 GB | 19 GB |
8B | 6 GB | 22 GB |
9B | 6.5 GB | 24 GB |
11B | 7.5 GB | 29 GB |
14B | 8.5 GB | 33 GB |
27B | 22 GB | 64 GB |
32B | 26 GB | 76 GB |
40B | 30 GB | 96 GB |
70B | 41 GB | 164 GB |
81B | 48 GB | 192 GB |
90B | 53 GB | 212 GB |
405B | 237 GB | 950 GB |
1.2、软件环境
序 | 名称 | 版本 |
1 | Python | 3.10+ |
2 | CUDA | 12.1+ |
3 | JupyterLab | 3.5+ |
2. Unsloth安装
2.1、创建并配置虚拟环境
打开一个新的命令行终端,创建Conda新环境,名称可自定义,这里以"unsloth"为例:
$ conda create -n unsloth python=3.11 ipykernel -y
激活新建的环境:
$ conda activate unsloth
激活后,终端提示符通常会显示环境名称(unsloth),表示您已在该环境当中。
将unsloth虚拟环境加入到Jupyterlab的内核中,以便后续.ipynb文档可以选择该环境运行:
$ python -m ipykernel install --user --name=unsloth --display-name "unsloth"
运行后,点击右上角内核切换按钮,进行内核切换,查看是否有出现unsloth内核,如果没有请在菜单栏重启内核再操作:
2.2、Unsloth安装
In [ ]:
import sys
PYTHON_PATH=sys.executable
print(PYTHON_PATH)
In [ ]:
%%capture
!{PYTHON_PATH} -m pip install unsloth modelscope ipywidgets tensorboard
- %%capture:隐藏命令的输出,避免安装过程中的冗长日志刷屏。但注意观察右上角的运行状态,显示"忙碌",请耐心等待。
如果是开发环境,可以继续运行以下命令,从 GitHub 仓库安装Unsloth的最新开发版(可能包含未发布的修复或功能)。
In [ ]:
!{PYTHON_PATH} -m pip install \
--force-reinstall \
--no-cache-dir \
--no-deps \
git+https://github.com/unslothai/unsloth.git
2.3、验证Unsloth
运行以下命令查看Unsloth的安装情况 ,如果安装成功,会显示版本号等信息。
In [ ]:
!{PYTHON_PATH} -m pip show unsloth
3. 通过Unsloth进行Qwen2.5-VL模型推理
3.1、Qwen多模态模型下载
通过ModelScope SDK将Qwen2.5-VL多模态模型下载到指定目录,使用的是7B经过指令微调后的模型。
In [ ]:
import os
from modelscope import snapshot_download
# 定义基座模型以及模型存放目录
MODEL_NAME_OR_PATH = "models/Qwen2.5-VL-7B-Instruct"
BASE_MODEL = "unsloth/Qwen2.5-VL-7B-Instruct"
# 如目录不存在,则下载模型
if not os.path.exists(MODEL_NAME_OR_PATH):
snapshot_download(BASE_MODEL, local_dir=MODEL_NAME_OR_PATH)
# 目录已存在,打印文件列表
else:
print("模型已存在,跳过下载")
files = [item for item in os.listdir(MODEL_NAME_OR_PATH) if not item.startswith('.')]
for file in files:
print(file)
3.2. 导入相关依赖库
In [ ]:
from unsloth import FastVisionModel
import torch
from PIL import Image, ImageOps
from IPython.display import display
from transformers import TextStreamer
3.3、加载模型和分词器
In [ ]:
model, tokenizer = FastVisionModel.from_pretrained(
model_name=MODEL_NAME_OR_PATH,
max_seq_length=2048,
dtype=None,
load_in_4bit=True,
load_in_8bit=False,
full_finetuning=False,
)
3.4、微调前的模型推理
将推理过程封装成一个函数,方便后续多次调用,代码如下:
In [ ]:
def inference(text, image_file, system_prompt = None):
"""
推理函数
Args
text: 输入的文本
image_file: 图片文件路径
system_prompt: 系统提示语,默认为None
"""
# 显示图片
image = Image.open(image_file)
image = ImageOps.exif_transpose(image)
display(image)
# 将模型切换到推理模式(会关闭 dropout 等训练专用层,优化推理速度)
FastVisionModel.for_inference(model)
# 构造符合ChatML风格的输入消息
messages = []
if system_prompt:
messages.append({"role": "system", "content": [{"type":"text", "text": system_prompt}]})
messages = [
{"role": "user", "content":
[
{"type": "image"},
{"type": "text", "text": text}
]}
]
# 将messages转换为模型所需的对话格式字符串
input_text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True)
# 图像会被编码为视觉特征向量,文本按正常分词流程处理
# 输出包含input_ids(文本)、pixel_values(图像)等键的字典
model_inputs = tokenizer(
text=input_text,
images=image,
padding=True,
add_special_tokens=False,
return_tensors="pt"
)
model_inputs = model_inputs.to(model.device)
# 通过TextStreamer实现流式输出
model.generate(
**model_inputs,
max_new_tokens=512,
use_cache = True,
temperature = 1.5,
min_p = 0.1,
streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True),
)
函数编写完成后,现在我们来先传递一个问题、一张图片给函数,看看Unsloth框架的模型推理效果。
In [ ]:
inference(text="图片表达了什么?" , image_file="assets/candy.jpg")
那么在进行微调前,我们首先验证下原模型Qwen2.5-VL,对含有数学公式的图片识别效果怎么样,同样调用上面的推理函数inference:
In [ ]:
inference(
text="为图片生成LaTeX表达式",
image_file="assets/demo_pic_1.jpg",
system_prompt="你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式"
)
观察模型生成的结果,将该结果通过LaTeX公式生成器验证下,看是否正确,同时记录起来供后续做对比。
对比两者,可以发现模型虽然在提示词的作用下,发挥作用,但是回答的并不正确。但下来我们需要对它进行微调,使其更适应处理复杂数学公式。
4. 通过Unsloth微调Qwen2.5-VL实现复杂数学公式的OCR
4.1、微调数据集的准备
通过huggingface datasets库下载数据集:
In [ ]:
from datasets import load_dataset
# 定义数据集名称及保存路径
dataset_name = "unsloth/LaTeX_OCR"
dataset_dir = "datasets/LaTeX_OCR"
exist = os.path.exists(dataset_dir)
dataset = load_dataset(dataset_name, split="train", cache_dir=dataset_dir )
if not exist:
print(f"数据集已下载保存到{dataset_dir},共 {len(dataset)} 条样本")
else:
print(f"数据集已存在,已从{dataset_dir}加载数据集")
数据集已存在,已从datasets/LaTeX_OCR加载数据集
让我们来简单了解一下这个数据集。我们看一看第三张图片是什么,以及对应的标题是什么。
In [3]:
dataset[2]["image"]
Out[3]:
In [4]:
dataset[2]["text"]
Out[4]:
'H ^ { \\prime } = \\beta N \\int d \\lambda \\biggl \\{ \\frac { 1 } { 2 \\beta ^ { 2 } N ^ { 2 } } \\partial _ { \\lambda } \\zeta ^ { \\dagger } \\partial _ { \\lambda } \\zeta + V ( \\lambda ) \\zeta ^ { \\dagger } \\zeta \\biggr \\} \\ .'
我们运行下一行代码,直接在JupyterLab中渲染上述dataset[2]["text"]的LaTeX表达式,看是否与图片一致:
In [5]:
from IPython.display import display, Math
latex = dataset[2]["text"]
display(Math(latex))
H′=βN∫dλ{12β2N2λζ+λζ+V(λ)ζ+ζ} .H′=βN∫dλ{12β2N2λζ+λζ+V(λ)ζ+ζ} .
可以发现与原图的公式一模型一样。
那么了解完数据集结构之后,我们需要将这些数据格式化成Qwen2.5-VL需要的Json格式(本质上所有视觉微调任务都是类似ChatML格式,ChatML格式仅仅是sharegpt格式的一种特殊情况),如下所示:
[
{ "role": "user",
"content": [{"type": "text", "text": Q}, {"type": "image", "image": image} ]
},
{ "role": "assistant",
"content": [{"type": "text", "text": A} ]
},
]
定义数据预处理函数data_process,目的是处理数据集的每条数据,将其格式化成Qwen2.5-VL需要的Json格式:
In [ ]:
instruction = "为图片生成LaTeX表达式"
def data_process(sample):
conversation = [
{ "role": "user", "content" : [
{"type" : "text", "text" : instruction},
{"type" : "image", "image" : sample["image"]} ]
},
{ "role" : "assistant",
"content" : [
{"type" : "text", "text" : sample["text"]} ]
},
]
return { "messages" : conversation }
调用数据处理预函数data_process,批量将所有数据格式化为微调输入格式,返回给新的变量converted_dataset:
In [ ]:
converted_dataset = [data_process(sample) for sample in dataset]
我们展示下经过格式化后的首条数据内容:
In [ ]:
converted_dataset[0]
4.2、LoRA微调配置
In [ ]:
model = FastVisionModel.get_peft_model(
model,
finetune_vision_layers = True,
finetune_language_layers = True,
finetune_attention_modules = True,
finetune_mlp_modules = True,
# target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
r = 16,
lora_alpha = 16,
lora_dropout = 0,
bias = "none",
use_gradient_checkpointing="unsloth",
random_state = 3407,
use_rslora = False,
loftq_config = None,
)
4.3、训练参数配置
In [ ]:
from trl import SFTTrainer, SFTConfig
from unsloth.trainer import UnslothVisionDataCollator
from unsloth import is_bf16_supported
from datetime import datetime
output_dir = f"outputs/exp_{datetime.now().strftime('%Y%m%d_%H%M')}"
# 将模型切换到训练模式
FastVisionModel.for_training(model)
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
data_collator = UnslothVisionDataCollator(model, tokenizer),
train_dataset = converted_dataset,
args = SFTConfig(
output_dir = output_dir,
per_device_train_batch_size = 2,
gradient_accumulation_steps = 4,
warmup_steps = 5,
max_steps = 10,
# num_train_epochs = 2,
learning_rate = 2e-4,
fp16 = not is_bf16_supported(),
bf16 = is_bf16_supported(),
report_to = "tensorboard",
logging_steps = 5,
logging_dir=output_dir,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
remove_unused_columns = False,
dataset_text_field = "",
dataset_kwargs = {"skip_prepare_dataset": True},
dataset_num_proc = 4,
max_seq_length = 2048,
),
)
4.4、启动训练
打印当前GPU显存信息:
In [ ]:
# 获取索引为0的GPU设备的详细属性
gpu_stats = torch.cuda.get_device_properties(0)
# 返回PyTorch当前预留的显存峰值
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# GPU的物理显存总量
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}.")
print(f"1)最大显存 = {max_memory} GB.")
print(f"2)预留 {start_gpu_memory} GB 的显存.")
调用train()开始训练:
In [ ]:
trainer_stats = trainer.train()
显示最终内存和时间统计:
In [ ]:
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"训练耗时:{trainer_stats.metrics['train_runtime']}秒.")
print(
f"训练耗时:{round(trainer_stats.metrics['train_runtime']/60, 2)}分钟."
)
print(f"峰值预留显存 = {used_memory} GB.")
print(f"LoRA训练专用显存峰值 = {used_memory_for_lora} GB.")
print(f"峰值预留显存占总显存比例 = {used_percentage} %.")
print(f"LoRA训练显存占总显存比例 = {lora_percentage} %.")
4.5、微调后结果分析
指定训练日志所在目录,调用tensorboard命令启动,会在--port指定的端口启动一个可视化WEB服务,在浏览器中打开 http://localhost:6006(如果是云服务器的话,根据IP或映射访问) 即可查看可视化结果。
In [ ]:
!tensorboard --logdir {output_dir} --port 6006
运行以上命令后,打开浏览器访问,如果损失率没有稳定下降,需要调整训练参数重新开始训练。
4.6、模型微调后的推理
现在开始运行微调后的模型,使用相同的推理函数、相同的图片以及提示词:
In [ ]:
inference(
text="为图片生成LaTeX表达式",
image_file="assets/demo_pic_1.jpg",
system_prompt="你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式"
)
将输出的结果拷贝到LaTeX公式生成器验证下:
继续与推理前的结果对比,可以发现经过微调后的模型,生成的结果更加接近、符合预期。但由于训练步数/轮次太少,因此生成的结果还并不能完全正确,感兴趣的大家可以继续增大训练轮次,但时间会久些。
4.7、保存微调模型
将最终模型保存为LoRA适配器,可以使用Huggingface的save_pretrained方法进行本地保存,同时也要把分词器保存。
In [ ]:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"LoRA权重文件已保存在:{output_dir}")
但上述代码只是保存了LoRA适配器,而不是完整的模型,通过以下代码保存为完整的float16精度模型。该精度的模型可以使用vLLM、transformers等工具进行加载推理。
In [ ]:
new_model_dir = "models/Qwen2.5-VL-7B-LaTeXOCR"
model.save_pretrained_merged(
new_model_dir,
tokenizer,
save_method="merged_16bit",
)
print(f"模型已合并并保存到:{new_model_dir}")
model.save_pretrained_merged方法会逐层检查基础模型,并去huggingface下载相应的基础模型,所以尽量开启HF国内镜像源或代理,不然会抵账,下载也需要点时间。
合并保存完成后,观察models/Qwen2.5-VL-7B-LaTeXOCR目录,生成了以下文件:
到此,我们使用Qwen2.5-VL多模态基座模型,通过Unsloth的QLoRA微调方法,成功训练了第一个模型,让其可以识别LaTeX公式。让你对Unsloth有个初始的认识,更多其它模型的训练方法,请继续往下实战。
相关推荐
- F5负载均衡器如何通过irules实现应用的灵活转发?
-
F5是非常强大的商业负载均衡器。除了处理性能强劲,以及高稳定性之外,F5还可以通过irules编写强大灵活的转发规则,实现web业务的灵活应用。irules是基于TCL语法的,每个iRules必须包含...
- 映射域名到NAS
-
前面介绍已经将域名映射到家庭路由器上,现在只需要在路由器上设置一下端口转发即可。假设NAS在内网的IP是192.168.1.100,NAS管理端口2000.你的域名是www.xxx.com,配置外部端...
- 转发(Forward)和重定向(Redirect)的区别
-
转发是服务器行为,重定向是客户端行为。转发(Forward)通过RequestDispatcher对象的forward(HttpServletRequestrequest,HttpServletRe...
- SpringBoot应用中使用拦截器实现路由转发
-
1、背景项目中有一个SpringBoot开发的微服务,经过业务多年的演进,代码已经累积到令人恐怖的规模,亟需重构,将之拆解成多个微服务。该微服务的接口庞大,调用关系非常复杂,且实施重构的人员大部分不是...
- 公司想搭建个网站,网站如何进行域名解析?
-
域名解析是将域名指向网站空间IP,让人们通过注册的域名可以方便地访问到网站的一种服务。IP地址是网络上标识站点的数字地址,为方便记忆,采用域名来代替IP地址标识站点地址。域名解析就是域名到IP地址的转...
- 域名和IP地址什么关系?如何通过域名解析IP?
-
一般情况下,访客通过域名和IP地址都能访问到网站,那么两者之间有什么关系吗?本文中科三方针对域名和IP地址的关系和区别,以及如何实现域名与IP的绑定做下介绍。域名与IP地址之间的关系IP地址是计算机的...
- 分享网站域名301重定向的知识
-
网站域名做301重定向操作时,一般需要由专业的技术来协助完成,如果用户自己在维护,可以按照相应的说明进行操作。好了,下面说说重点,域名301重定向的操作步骤。首先,根据HTTP协议,在客户端向服务器发...
- NAS外网到底安全吗?一文看懂HTTP/HTTPS和SSL证书
-
本内容来源于@什么值得买APP,观点仅代表作者本人|作者:可爱的小cherry搭好了NAS,但是不懂做好网络加密,那么隐私泄露也会随时发生!大家好,这里是Cherry,喜爱折腾、玩数码,热衷于分享数...
- ForwardEmail免费、开源、加密的邮件转发服务
-
ForwardEmail是一款免费、加密和开源的邮件转发服务,设置简单只需4步即可正常使用,通过测试来看也要比ImprovMX好得多,转发近乎秒到且未进入垃圾箱(仅以Mailbox.org发送、Out...
- 使用CloudFlare进行域名重定向
-
当网站变更域名的时候,经常会使用域名重定向的方式,将老域名指向到新域名,这通常叫做:URL转发(URLFORWARDING),善于使用URL转发,对SEO来说非常有用,因为用这种方式能明确告知搜索引...
- 要将端口5002和5003通过Nginx代理到一个域名上的操作笔记
-
要将端口5002和5003通过Nginx代理到域名www.4rvi.cn的不同路径下,请按照以下步骤配置Nginx:步骤说明创建或编辑Nginx配置文件通常配置文件位于/etc/nginx/sites...
- SEO浅谈:网站域名重定向的三种方式
-
在大多数情况下,我们输入网站访问网站的时候,很难发现www.***.com和***.com的区别,因为一般的网站主,都会把这两个域名指向到同一网站。但是对于网站运营和优化来说,www.***.com和...
- 花生壳出现诊断域名与转发服务器ip不一致的解决办法
-
出现诊断域名与转发服务器ip不一致您可以:1、更改客户端所处主机的drs为223.5.5.5备用dns为119.29.29.29;2、在windows上进入命令提示符输入ipconfig/flush...
- 涨知识了!带你认识什么是域名
-
1、什么是域名从技术角度来看,域名是在Internet上解决IP地址对应的一种方法。一个完整的域名由两个或两个以上部分组成,各部分之间用英文的句号“.”来分隔。如“abc.com”。其中“com”称...
- 域名被跳转到其他网站是怎么回事
-
当你输入域名时被跳转到另一个网站,这可能是由几种原因造成的:一、域名可能配置了域名转发服务。无论何时有人访问域名,比如.com、.top等,都会自动重定向到另一个指定的URL,这通常是在域名注册商设...
你 发表评论:
欢迎- 一周热门
-
-
爱折腾的特斯拉车主必看!手把手教你TESLAMATE的备份和恢复
-
如何在安装前及安装后修改黑群晖的Mac地址和Sn系列号
-
[常用工具] OpenCV_contrib库在windows下编译使用指南
-
WindowsServer2022|配置NTP服务器的命令
-
Ubuntu系统Daphne + Nginx + supervisor部署Django项目
-
WIN11 安装配置 linux 子系统 Ubuntu 图形界面 桌面系统
-
解决Linux终端中“-bash: nano: command not found”问题
-
NBA 2K25虚拟内存不足/爆内存/内存占用100% 一文速解
-
Linux 中的文件描述符是什么?(linux 打开文件表 文件描述符)
-
K3s禁用Service Load Balancer,解决获取浏览器IP不正确问题
-
- 最近发表
- 标签列表
-
- linux 查询端口号 (58)
- docker映射容器目录到宿主机 (66)
- 杀端口 (60)
- yum更换阿里源 (62)
- internet explorer 增强的安全配置已启用 (65)
- linux自动挂载 (56)
- 禁用selinux (55)
- sysv-rc-conf (69)
- ubuntu防火墙状态查看 (64)
- windows server 2022激活密钥 (56)
- 无法与服务器建立安全连接是什么意思 (74)
- 443/80端口被占用怎么解决 (56)
- ping无法访问目标主机怎么解决 (58)
- fdatasync (59)
- 405 not allowed (56)
- 免备案虚拟主机zxhost (55)
- linux根据pid查看进程 (60)
- dhcp工具 (62)
- mysql 1045 (57)
- 宝塔远程工具 (56)
- ssh服务器拒绝了密码 请再试一次 (56)
- ubuntu卸载docker (56)
- linux查看nginx状态 (63)
- tomcat 乱码 (76)
- 2008r2激活序列号 (65)