Model Compression¶
Tổng quan¶
Bộ công cụ nén và tối ưu mô hình ML:
- HocRutGon - Knowledge Distillation (huấn luyện mô hình nhỏ từ mô hình lớn)
- CatTiaMoHinh - Pruning (cắt tỉa trọng số không quan trọng)
HocRutGon (Knowledge Distillation)¶
Huấn luyện mô hình nhỏ (student) từ mô hình lớn (teacher) bằng soft labels.
Cơ bản¶
from vietnamese_ai import HocRutGon
teacher = PhanLoai(thuat_toan="rung_ngau_nhien")
teacher.huan_luyen(X_train, y_train)
distiller = HocRutGon(
teacher=teacher,
nhiet_do=3.0,
alpha=0.5,
ham_loss="kl_divergence",
)
student = PhanLoai(thuat_toan="logistic")
ket_qua = distiller.huan_luyen(student, X_train, y_train, X_val, y_val)
print(f"Teacher accuracy: {ket_qua['teacher_acc']:.4f}")
print(f"Student accuracy: {ket_qua['student_acc']:.4f}")
print(f"Tỷ lệ nén: {ket_qua['ty_le_nen']:.1f}x")
Soft Labels và Temperature¶
Temperature scaling điều khiển mức độ "mịn" của xác suất teacher:
T = 1.0: Xác suất gốc (sharp)T > 1.0: Xác suất mịn hơn (soft), student học được nhiều thông tin hơnT < 1.0: Xác suất sắc nét hơn
# Temperature cao = soft labels mịn hơn
distiller = HocRutGon(teacher=teacher, nhiet_do=5.0)
# Temperature thấp = soft labels sắc nét
distiller = HocRutGon(teacher=teacher, nhiet_do=1.5)
Tham số¶
| Tham số | Mặc định | Mô tả |
|---|---|---|
nhiet_do |
3.0 |
Temperature scaling |
alpha |
0.5 |
Trọng số soft labels vs hard labels |
so_vong |
10 |
Số epochs |
ham_loss |
"kl_divergence" |
Loss function: kl_divergence, mse, cross_entropy |
Ensemble Distillation¶
Distillation từ nhiều teachers:
teacher_1 = PhanLoai(thuat_toan="rung_ngau_nhien")
teacher_2 = PhanLoai(thuat_toan="svm")
teacher_3 = PhanLoai(thuat_toan="gradient_boosting")
for t in [teacher_1, teacher_2, teacher_3]:
t.huan_luyen(X_train, y_train)
distiller = HocRutGon(teacher=teacher_1, nhiet_do=3.0)
student = PhanLoai(thuat_toan="logistic")
ket_qua = distiller.huan_luyen_ensemble(
student=student,
teachers=[teacher_1, teacher_2, teacher_3],
X=X_train,
y=y_train,
)
print(f"Student accuracy: {ket_qua['student_acc']:.4f}")
print(f"Số teachers: {ket_qua['so_teachers']}")
Thống kê¶
print(distiller.thong_ke())
# {"nhiet_do": 3.0, "alpha": 0.5, "ham_loss": "kl_divergence", "so_lan_distill": 1}
lich_su = distiller.lay_lich_su()
CatTiaMoHinh (Model Pruning)¶
Cắt tỉa trọng số không quan trọng để giảm kích thước và tăng tốc mô hình.
Magnitude-based Pruning (mặc định)¶
Xóa trọng số có giá trị tuyệt đối nhỏ nhất:
from vietnamese_ai import CatTiaMoHinh
pruner = CatTiaMoHinh(che_do="magnitude", ty_le=0.5)
ket_qua = pruner.cat_tia(model, X_train, y_train)
print(f"Pruned: {ket_qua['ty_le_prune']*100:.1f}% weights")
print(f"Hiệu suất trước: {ket_qua['hieu_suat_truoc']}")
print(f"Hiệu suất sau: {ket_qua['hieu_suat_sau']}")
Structured Pruning¶
Xóa nguyên neuron/layer dựa trên L2 norm:
pruner = CatTiaMoHinh(che_do="structured", ty_le=0.3)
ket_qua = pruner.cat_tia(model, X_train, y_train)
Iterative Pruning¶
Prune dần dần, mỗi lần prune một ít rồi retrain:
pruner = CatTiaMoHinh(
che_do="iterative",
ty_le=0.5,
so_vong_lap=5,
)
ket_qua = pruner.cat_tia(model, X_train, y_train)
Random Pruning¶
Prune ngẫu nhiên (baseline để so sánh):
Chế độ pruning¶
| Chế độ | Mô tả |
|---|---|
"magnitude" |
Xóa trọng số nhỏ ( |
"structured" |
Xóa nguyên neuron dựa trên L2 norm |
"iterative" |
Prune dần + retrain (Lottery Ticket) |
"random" |
Prune ngẫu nhiên (baseline) |
Tham số¶
| Tham số | Mặc định | Mô tả |
|---|---|---|
che_do |
"magnitude" |
Chiến lược pruning |
ty_le |
0.5 |
Tỷ lệ trọng số cần prune (0-1) |
so_vong_lap |
1 |
Số vòng lặp (iterative) |
phuc_hoi |
False |
Có phục hồi trọng số không |
Lấy mask và thống kê¶
mask = pruner.lay_mask()
print(f"Số trọng số giữ lại: {mask.sum()}/{len(mask)}")
lich_su = pruner.lay_lich_su()
print(pruner.thong_ke())
So sánh Distillation vs Pruning¶
| Đặc điểm | Distillation | Pruning |
|---|---|---|
| Mục đích | Mô hình nhỏ hơn | Giảm trọng số không cần thiết |
| Đầu vào | Teacher model + Data | Model + Data (tùy chọn) |
| Đầu ra | Student model mới | Model đã prune |
| Tốc độ | Chậm (train lại) | Nhanh |
| Chất lượng | Thường tốt hơn | Có thể giảm nhẹ |
Kết hợp Distillation + Pruning¶
from vietnamese_ai import HocRutGon, CatTiaMoHinh
distiller = HocRutGon(teacher=teacher, nhiet_do=3.0)
student = PhanLoai(thuat_toan="logistic")
distiller.huan_luyen(student, X_train, y_train)
pruner = CatTiaMoHinh(che_do="magnitude", ty_le=0.3)
pruner.cat_tia(student, X_train, y_train)
print(f"Student size: {pruner._dem_tham_so(student)}")