前言

大型語言模型(LLM)很強,但不是每個任務都需要動用 GPT-4 等級的怪獸。很多時候,一個 7B 甚至 3B 參數的小模型,經過針對性的 fine-tune,就能在特定任務上達到令人驚豔的效果——而且推論成本只有大模型的十分之一甚至百分之一。

問題是,全參數 fine-tune 一個 7B 模型需要的 GPU 記憶體動輒 50GB 以上,對多數團隊來說門檻太高。這就是 LoRA(Low-Rank Adaptation)出場的時機。LoRA 讓你只訓練極少量的參數(通常不到原模型的 1%),就能達到接近全參數微調的效果。

這篇文章會從 LoRA 的原理講起,帶你走一遍完整的 fine-tune 流程:從資料集準備、訓練設定、到最後用 Hugging Face 部署推論。我自己用這套流程在公司內部做過好幾個專案,包括客服分類、程式碼審查建議、以及技術文件摘要,效果都很不錯。

LoRA 原理簡介

為什麼全參數微調太貴

一個標準的 Transformer 模型,每一層都有巨大的權重矩陣。以 LLaMA-7B 為例,光是 attention 層的 Q、K、V 投影矩陣就有數十億參數。全參數微調意味著:

  • 需要儲存完整的模型權重梯度(FP32 約 28GB)
  • 需要 optimizer states(Adam 需要額外 2 倍)
  • 實際訓練至少需要 80GB VRAM

LoRA 的核心想法

LoRA 的論文觀察到一個關鍵現象:fine-tune 時,權重的變化量(delta W)其實是低秩的(low-rank)。既然如此,與其直接更新整個矩陣 W,不如把 delta W 分解成兩個小矩陣的乘積:

原始前向傳播:y = Wx
LoRA 修改後:  y = Wx + BAx

其中: W: 原始權重矩陣 (d × d),凍結不動 A: 降維矩陣 (d × r),r << d B: 升維矩陣 (r × d),初始化為 0

r 就是所謂的 rank,通常設 8、16、32 就夠了。這樣一來,可訓練參數從 d² 降到 2dr,對一個 4096 維的層來說,rank=16 就是從 1600 萬參數降到 13 萬——減少了 99%。

關鍵超參數

# LoRA 的兩個核心超參數
lora_config = {
    "r": 16,           # rank,越大表達力越強但參數越多
    "lora_alpha": 32,  # scaling factor,通常設為 2*r
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],  # 要套用 LoRA 的模組
}

我的經驗法則:

  • 簡單分類任務:r=8 就夠
  • 生成任務(摘要、翻譯):r=16~32
  • 複雜推理任務:r=64,但要注意過擬合

資料集準備

資料集品質決定了 fine-tune 的成敗。我見過太多人花大把時間調超參數,結果問題出在資料本身。

資料格式

Hugging Face 的 trl 套件支援多種格式,最常用的是 instruction 格式:

{
  "instruction": "請將以下客服訊息分類為:退貨、帳務、技術問題、其他",
  "input": "我上週買的耳機壞了,想要退貨退款",
  "output": "退貨"
}

或者用 chat 格式:

{
  "messages": [
    {"role": "system", "content": "你是一個客服分類助手"},
    {"role": "user", "content": "我上週買的耳機壞了,想要退貨退款"},
    {"role": "assistant", "content": "退貨"}
  ]
}

資料集準備腳本

import json
import pandas as pd
from datasets import Dataset
from sklearn.model_selection import train_test_split

# 假設你有一份 CSV 格式的標註資料 df = pd.read_csv("labeled_tickets.csv") # columns: text, category

def format_instruction(row): return { "messages": [ { "role": "system", "content": "你是一個客服分類助手。請將客戶訊息分類為以下類別之一:退貨、帳務、技術問題、其他。只回覆類別名稱。" }, {"role": "user", "content": row["text"]}, {"role": "assistant", "content": row["category"]} ] }

# 轉換格式 formatted = df.apply(format_instruction, axis=1).tolist()

# 切分訓練/驗證集(90/10) train_data, val_data = train_test_split(formatted, test_size=0.1, random_state=42)

# 儲存為 JSONL for split_name, data in [("train", train_data), ("val", val_data)]: with open(f"{split_name}.jsonl", "w", encoding="utf-8") as f: for item in data: f.write(json.dumps(item, ensure_ascii=False) + "\n")

print(f"訓練集: {len(train_data)} 筆, 驗證集: {len(val_data)} 筆")

資料品質檢查清單

在開始訓練前,我會做這些檢查:

# 資料品質檢查
def check_dataset(data):
    # 1. 檢查空值
    empty_count = sum(1 for d in data if not d["messages"][-1]["content"].strip())
    print(f"空回覆數量: {empty_count}")

# 2. 檢查類別分佈 from collections import Counter labels = [d["messages"][-1]["content"] for d in data] dist = Counter(labels) print(f"類別分佈: {dict(dist)}")

# 3. 檢查文字長度 lengths = [len(d["messages"][1]["content"]) for d in data] print(f"輸入長度 - 平均: {sum(lengths)/len(lengths):.0f}, " f"最長: {max(lengths)}, 最短: {min(lengths)}")

# 4. 檢查重複 texts = [d["messages"][1]["content"] for d in data] dup_count = len(texts) - len(set(texts)) print(f"重複資料: {dup_count} 筆")

check_dataset(train_data)

訓練流程

環境安裝

pip install torch transformers peft trl datasets accelerate bitsandbytes

完整訓練腳本

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_dataset

# === 1. 模型與量化設定 === model_name = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit 量化,大幅降低記憶體需求 bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, )

# 載入模型 model = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=bnb_config, device_map="auto", trust_remote_code=True, ) model = prepare_model_for_kbit_training(model)

# 載入 tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right"

# === 2. LoRA 設定 === lora_config = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], )

model = get_peft_model(model, lora_config) model.print_trainable_parameters() # 輸出類似:trainable params: 13,631,488 || all params: 8,043,724,800 || trainable%: 0.1695

# === 3. 載入資料集 === dataset = load_dataset("json", data_files={ "train": "train.jsonl", "validation": "val.jsonl", })

# === 4. 訓練參數 === training_args = TrainingArguments( output_dir="./lora-output", num_train_epochs=3, per_device_train_batch_size=4, gradient_accumulation_steps=4, # 等效 batch size = 16 learning_rate=2e-4, warmup_ratio=0.1, lr_scheduler_type="cosine", logging_steps=10, save_strategy="epoch", evaluation_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="eval_loss", bf16=True, gradient_checkpointing=True, # 省記憶體 report_to="wandb", # 可選,用 wandb 追蹤訓練 )

# === 5. 開始訓練 === trainer = SFTTrainer( model=model, args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["validation"], tokenizer=tokenizer, max_seq_length=1024, )

trainer.train()

# === 6. 儲存 LoRA adapter === trainer.model.save_pretrained("./lora-adapter") tokenizer.save_pretrained("./lora-adapter") print("LoRA adapter 已儲存")

訓練監控重點

訓練過程中我會關注幾個指標:

  • train_loss:應該穩定下降,如果震盪劇烈就調低 learning rate
  • eval_loss:如果 train_loss 持續下降但 eval_loss 開始上升,就是過擬合了
  • learning rate schedule:cosine schedule 通常比 linear 好
# 訓練完成後檢查結果
import matplotlib.pyplot as plt

history = trainer.state.log_history train_loss = [(h["step"], h["loss"]) for h in history if "loss" in h] eval_loss = [(h["step"], h["eval_loss"]) for h in history if "eval_loss" in h]

plt.figure(figsize=(10, 5)) plt.plot(zip(train_loss), label="Train Loss") plt.plot(zip(eval_loss), label="Eval Loss", marker="o") plt.xlabel("Step") plt.ylabel("Loss") plt.legend() plt.title("Training Progress") plt.savefig("training_curve.png")

Hugging Face 整合與部署

上傳到 Hugging Face Hub

from huggingface_hub import login

login(token="hf_your_token_here")

# 上傳 LoRA adapter(不是完整模型,只有幾十 MB) trainer.model.push_to_hub("your-username/my-lora-adapter") tokenizer.push_to_hub("your-username/my-lora-adapter")

推論程式碼

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

# 載入基礎模型 base_model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto", )

# 套用 LoRA adapter model = PeftModel.from_pretrained(base_model, "your-username/my-lora-adapter") tokenizer = AutoTokenizer.from_pretrained("your-username/my-lora-adapter")

# 推論 messages = [ {"role": "system", "content": "你是一個客服分類助手。"}, {"role": "user", "content": "我的訂單已經一週了還沒收到"}, ]

input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

with torch.no_grad(): output = model.generate(input_ids, max_new_tokens=50, temperature=0.1)

result = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True) print(f"分類結果: {result}")

合併模型以加速推論

如果確定 adapter 效果滿意,可以把 LoRA 權重合併回基礎模型,推論時就不需要額外的 adapter loading 開銷:

# 合併 LoRA adapter 到基礎模型
merged_model = model.merge_and_unload()

# 儲存完整模型 merged_model.save_pretrained("./merged-model") tokenizer.save_pretrained("./merged-model")

用 vLLM 部署高效推論服務

# 安裝 vLLM
pip install vllm

# 啟動推論伺服器(支援 OpenAI 相容 API) python -m vllm.entrypoints.openai.api_server \ --model ./merged-model \ --host 0.0.0.0 \ --port 8000 \ --max-model-len 2048

# 用 OpenAI SDK 呼叫
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create( model="./merged-model", messages=[ {"role": "system", "content": "你是一個客服分類助手。"}, {"role": "user", "content": "我想取消訂閱"}, ], temperature=0.1, ) print(response.choices[0].message.content)

小結

LoRA 讓小團隊也能玩得起 fine-tune,關鍵心得整理:

  1. 資料品質 > 資料量 > 模型大小 > 超參數:先把資料弄乾淨再說
  2. 從小模型開始:3B 模型 + LoRA 往往比你想像的好用,不要一上來就拿 70B
  3. r 不用太大:大部分任務 r=16 就夠,r=64 通常過度了
  4. 務必做驗證集監控:過擬合是 fine-tune 最常見的問題
  5. 合併後用 vLLM 部署:生產環境的推論效能差異巨大

延伸閱讀建議: