FINE TUNING LÀ GÌ

  -  
1. Introduction

1.1 Fine-tuning là gì ?

Chắc hẳn phần lớn ai làm việc với các mã sản phẩm trong deep learning hồ hết đã nghe/quen với quan niệm Transfer learning với Fine tuning. Quan niệm tổng quát: Transfer learning là tận dụng tri thức học được từ một vấn đề để áp dụng vào 1 vấn đề có liên quan khác. Một ví dụ đối chọi giản: thay bởi vì train 1 mã sản phẩm mới hoàn toàn cho bài toán phân các loại chó/mèo, fan ta rất có thể tận dụng 1 model đã được train bên trên ImageNet dataset với hằng triệu ảnh. Pre-trained mã sản phẩm này sẽ tiến hành train tiếp trên tập dataset chó/mèo, quy trình train này ra mắt nhanh hơn, công dụng thường tốt hơn. Có khá nhiều kiểu Transfer learning, các chúng ta cũng có thể tham khảo trong bài bác này: Tổng thích hợp Transfer learning. Trong bài bác này, mình đã viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine tuning là gì

Hiểu solo giản, fine-tuning là chúng ta lấy 1 pre-trained model, tận dụng 1 phần hoặc toàn bộ các layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo nên 1 mã sản phẩm mới. Thường những layer đầu của model được freeze (đóng băng) lại - tức weight những layer này sẽ không bị đổi khác giá trị trong quá trình train. Nguyên nhân bởi những layer này đã có tác dụng trích xuất thông tin mức trìu tượng thấp , kĩ năng này được học từ quy trình training trước đó. Ta freeze lại để tận dụng được năng lực này và giúp vấn đề train ra mắt nhanh hơn (model chỉ cần update weight ở những layer cao). Có không ít các Object detect model được thi công dựa trên các Classifier model. VD Retina model (Object detect) được chế tạo với backbone là Resnet.

*

1.2 lý do pytorch thay bởi Keras ?

Chủ đề nội dung bài viết hôm nay, mình sẽ giải đáp fine-tuning Resnet50 - 1 pre-trained model được hỗ trợ sẵn vào torchvision của pytorch. Tại sao là pytorch mà chưa hẳn Keras ? nguyên nhân bởi vấn đề fine-tuning model trong keras rất 1-1 giản. Dưới đây là 1 đoạn code minh hoạ cho vấn đề xây dựng 1 Unet dựa trên Resnet trong Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer("activation_9").outputlayer_7 = resnet.get_layer("activation_21").outputlayer_13 = resnet.get_layer("activation_39").outputlayer_16 = resnet.get_layer("activation_48").output#Adding outputs decoder with encoder layersfcn1 = Conv2D(...)(layer_16)fcn2 = Conv2DTranspose(...)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(...)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(...)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(...)(fcn4_skip_connected)Unet = Model(inputs = resnet.input, outputs=fcn5)Bạn rất có thể thấy, fine-tuning model trong Keras thực sự rất đơn giản, dễ làm, dễ dàng hiểu. Việc địa chỉ thêm những nhánh rất đơn giản bởi cú pháp solo giản. Trong pytorch thì ngược lại, xây dựng 1 model Unet tương tự sẽ tương đối vất vả cùng phức tạp. Tín đồ mới học sẽ gặp khó khăn vày trên mạng ko nhiều những hướng dẫn cho việc này. Vậy nên bài xích này mình sẽ hướng dẫn chi tiết cách fine-tune vào pytorch để vận dụng vào bài toán Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?

*

Khi nhìn vào 1 bức ảnh, mắt thường xuyên có xu thế tập trung quan sát vào 1 vài cửa hàng chính. Ảnh trên đó là 1 minh hoạ, màu quà được sử dụng để thể hiện mức độ thu hút. Saliency prediction là việc mô rộp sự tập trung của mắt bạn khi quan sát 1 bức ảnh. Vậy thể, bài xích toán yên cầu xây dựng 1 model, model này nhận ảnh đầu vào, trả về 1 mask mô bỏng mức độ thu hút. Như vậy, model nhận vào 1 input image và trả về 1 mask có kích thước tương đương.

Để rõ rộng về câu hỏi này, bạn cũng có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Dataset phổ cập nhất: SALICON DATASET

2.2 Unet

Note: Bạn có thể bỏ qua phần này nếu vẫn biết về Unet

Đây là 1 trong bài toán Image-to-Image. Để xử lý bài toán này, mình sẽ xây dựng dựng 1 mã sản phẩm theo bản vẽ xây dựng Unet. Unet là 1 trong kiến trúc được sử dụng nhiều trong câu hỏi Image-to-image như: semantic segmentation, tự động hóa color, super resolution ... Phong cách thiết kế của Unet bao gồm điểm tựa như với kiến trúc Encoder-Decoder đối xứng, được thêm những skip connection tự Encode sang Decode tương ứng. Về cơ bản, những layer càng tốt càng trích xuất tin tức ở mức trìu tượng cao, điều đó đồng nghĩa với việc các thông tin nấc trìu tượng phải chăng như đường nét, màu sắc sắc, độ phân giải... Có khả năng sẽ bị mất mát đi trong quá trình lan truyền. Tín đồ ta thêm các skip-connection vào để giải quyết vấn đề này.

Với phần Encode, feature-map được downscale bằng các Convolution. Ngược lại, ở vị trí decode, feature-map được upscale bởi các Upsampling layer, trong bài xích này mình sử dụng những Convolution Transpose.

*

2.3 Resnet

Để giải quyết bài toán, mình sẽ xây dựng dựng mã sản phẩm Unet với backbone là Resnet50. Các bạn nên khám phá về Resnet nếu chưa chắc chắn về phong cách thiết kế này. Hãy quan sát hình minh hoạ dưới đây. Resnet50 được phân thành các khối béo . Unet được phát hành với Encoder là Resnet50. Ta sẽ lấy ra output của từng khối, tạo những skip-connection kết nối từ Encoder sang Decoder. Decoder được tạo bởi các Convolution Transpose layer (xen kẽ trong số đó là những lớp Convolution nhằm mục đích mục đích bớt số chanel của feature bản đồ -> giảm con số weight đến model).

Theo ý kiến cá nhân, pytorch rất dễ dàng code, dễ hiểu hơn không hề ít so cùng với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, việc fine-tuning model vào pytorch lại cực nhọc hơn không ít so với Keras. Trong Keras, ta không đề nghị quá vồ cập tới con kiến trúc, luồng cách xử trí của model, chỉ cần lấy ra các output tại 1 số layer nhất định làm skip-connection, ghép nối và tạo thành ra mã sản phẩm mới.

*

Trong pytorch thì ngược lại, bạn cần hiểu được luồng xử lý và copy code đầy đủ layer ước ao giữ lại trong model mới. Hình bên trên là code của resnet50 vào torchvision. Bạn có thể tham khảo link: torchvision-resnet50. Bởi vậy khi phát hành Unet như bản vẽ xây dựng đã mô tả bên trên, ta cần bảo đảm an toàn đoạn code từ bỏ Conv1 -> Layer4 không xẩy ra thay đổi. Hãy đọc phần tiếp theo để làm rõ hơn.

Xem thêm: Hướng Dẫn Chơi Game Mobile Trên Pc, Cách Chơi Game Android Trên Pc Không Cần Giả Lập

3. Code

Tất cả code của chính bản thân mình được gói gọn trong tệp tin notebook Salicon_main.ipynb. Chúng ta cũng có thể tải về với run code theo link github: github/trungthanhnguyen0502 . Trong nội dung bài viết mình đã chỉ gửi ra phần đa đoạn code chính.

Import các package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ....

3.1 utils functions

Trong pytorch, tài liệu có đồ vật tự dimension khác với Keras/TF/numpy. Thường thì với numpy giỏi keras, hình ảnh có dimension theo vật dụng tự (batchsize,h,w,chanel)(batchsize, h, w, chanel)(batchsize,h,w,chanel). Sản phẩm công nghệ tự trong Pytorch trái lại là (batchsize,chanel,h,w)(batchsize, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng dựng 2 hàm toTensor và toNumpy để thay đổi qua lại thân hai format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): ... ## display multi imagedef plot_imgs(imgs): ...

3.2 Define model

3.2.1 Conv and Deconv

Mình sẽ xây dựng 2 function trả về module Convolution và Convolution Transpose (Deconv)

def Deconv(n_input, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = < Tconv, nn.BatchNorm2d(n_output), nn.LeakyReLU(inplace=True), > return nn.Sequential(*block) def Conv(n_input, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False) block = < conv, nn.BatchNorm2d(n_output), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout) > return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta vẫn copy các layer đề nghị giữ trường đoản cú resnet50 vào unet. Tiếp nối khởi tạo những Conv / Deconv layer và những layer đề xuất thiết.

Forward function: cần bảo vệ luồng cách xử lý của resnet50 được không thay đổi giống code cội (trừ Fully-connected layer). Tiếp nối ta ghép nối những layer lại theo phong cách xây dựng Unet đã miêu tả trong phần 2.

Tạo model: nên load resnet50 cùng truyền vào Unet. Đừng quên Freeze các layer của resnet50 trong Unet.

Xem thêm: Cây Bàng Non Là Gì - Vì Anh Thương Em Như Cây Bàng Non Là Sao

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet lớn make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use to reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device("cuda")resnet50 = models.resnet50(pretrained=True)model = Unet(resnet50)model.to(device)## Freeze resnet50"s layers in Unetfor i, child in enumerate(model.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset and Dataloader

Dataset trả nhấn 1 list các image_path với mask_dir, trả về image và mask tương ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split("/")<-1>.split(".")<0> mask_fn = f"self.mask_dir/img_name.png" img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = "image": img, "mask": mask sample = self.transforms(**sample) img = sample<"image"> mask = sample<"mask"> # to Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob("./Salicon_dataset/image/train/*.jpg")mask_dir = "./Salicon_dataset/mask/train"train_transform = A.Compose(< A.Resize(width=256,height=256, p=1), A.RandomSizedCrop(<240,256>, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataset = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)# demo datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)<:,:,0>img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài xích toán đơn giản dễ dàng và làm cho dễ hiểu, mình đang train theo cách đơn giản và dễ dàng nhất, không validate vào qúa trình train mà chỉ lưu mã sản phẩm sau 1 số epoch tốt nhất định

train_params = optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5model.train()saved_dir = "model"os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce="mean")for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 test model

img_fns = glob("./Salicon_dataset/image/val/*.jpg")mask_dir = "./Salicon_dataset/mask/val"val_transform = A.Compose(< A.Resize(width=256,height=256, p=1), A.HorizontalFlip(p=0.5),>)model.eval()val_dataset = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết quả thu được:

*

Đây là bài bác toán dễ dàng và đơn giản nên mình chú trọng vào quy trình và phương thức fine tuning vào pytorch hơn là đi sâu vào giải quyết và xử lý bài toán. Cảm ơn các bạn đã đọc

4. Reference

Dataset: salicon.net

Code bài bác viết: https://github.com/trungthanhnguyen0502/-ceds.edu.vn-Visual-Saliency-prediction

Resnet50 torchvision code: torchvision-resnet

Bài viết cùng chủ đề Visual saliency: Visual Saliency Prediction with Contextual Encoder-Decoder Network!

Theo dõi các bài viết chuyên sâu về AI/Deep learning tại: Vietnam AI links Sharing Community