huggingface vit训练代码,可以改dataset训练自己的数据

发布时间 2023-05-26 00:16:14作者: 高颜值的殺生丸

 

见代码:

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import ToTensor
from torchvision.models import resnet101
from tqdm import tqdm
 
# Select the compute device: prefer a CUDA GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("mps")  #torch.device("cpu")

# Load the CIFAR-10 dataset (the original comment said "MNIST", but CIFAR10 is
# what is actually used). ToTensor() converts images to float tensors in [0, 1].
# NOTE(review): the root path is machine-specific — replace it with your own
# data directory to train on your own setup.
train_dataset = CIFAR10(root="/Users/xinyuuliu/Desktop/test_python/", train=True, transform=ToTensor(), download=True)
test_dataset = CIFAR10(root="/Users/xinyuuliu/Desktop/test_python/", train=False, transform=ToTensor())


def collate_fn(batch):
    """Collate a batch of (image, label) samples from the dataset.

    The images are kept as a tuple of tensors (they are later fed to the
    HuggingFace image processor, which accepts a sequence of images);
    only the labels are converted to a tensor here.

    :param batch: list of (image_tensor, label) pairs produced by __getitem__
    :return: tuple (images_tuple, labels_tensor)
    """
    images, labels = zip(*batch)
    # torch.tensor(...) infers int64 for integer class labels; the legacy
    # torch.Tensor(...) constructor used previously silently produced float32,
    # forcing a .long() cast before the loss.
    labels = torch.tensor(labels)

    return images, labels
# Create the data loaders; collate_fn keeps images as a tuple so the
# HuggingFace processor can preprocess them per batch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False,collate_fn=collate_fn)

# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)

# Load the pretrained ViT image processor and classification model from the
# HuggingFace hub (downloads weights on first run).
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# print(model.get_output_embeddings)
# print(model.classifier)
# Replace the 1000-class ImageNet head with a fresh 10-class head for CIFAR-10.
# 768 is the hidden size of ViT-base; to train on your own data, change 10 to
# your number of classes.
model.classifier = nn.Linear(768,10)
print(model.classifier)
model.to(device)

# Define the loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(model, dataloader, optimizer, criterion):
    """Run one training epoch and return the mean loss per sample.

    :param model: ViTForImageClassification (or compatible) model already on `device`
    :param dataloader: yields (images_tuple, labels_tensor) batches (see collate_fn)
    :param optimizer: torch optimizer over model.parameters()
    :param criterion: classification loss, e.g. nn.CrossEntropyLoss
    :return: average loss over all samples in dataloader.dataset
    """
    model.train()
    running_loss = 0.0
    for inputs, labels in tqdm(dataloader, desc="Training"):
        # The dataset already applies ToTensor(), so pixel values are in
        # [0, 1]. The processor defaults to do_rescale=True (multiply by
        # 1/255), which would double-scale the images to near zero —
        # disable it and let the processor only resize/normalize.
        inputs = processor(images=inputs, do_rescale=False, return_tensors="pt")
        inputs['pixel_values'] = inputs['pixel_values'].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(**inputs)
        logits = outputs.logits

        # CrossEntropyLoss expects int64 class indices.
        loss = criterion(logits, labels.long())
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch total divides
        # cleanly into a per-sample average.
        running_loss += loss.item() * inputs['pixel_values'].size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

def evaluate(model, dataloader):
    """Compute classification accuracy (in percent) of `model` on `dataloader`.

    :param model: ViTForImageClassification (or compatible) model already on `device`
    :param dataloader: yields (images_tuple, labels_tensor) batches (see collate_fn)
    :return: accuracy as a float percentage in [0, 100]
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            # Images are already in [0, 1] from ToTensor(); skip the
            # processor's default 1/255 rescale so evaluation preprocessing
            # matches training preprocessing.
            inputs = processor(images=inputs, do_rescale=False, return_tensors="pt")
            inputs['pixel_values'] = inputs['pixel_values'].to(device)
            labels = labels.to(device)

            outputs = model(**inputs)
            logits = outputs.logits

            # Predicted class = argmax over the 10 output logits.
            predicted = logits.argmax(-1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total * 100
    return accuracy


# Training and evaluation: run each epoch of training, then report test accuracy.
num_epochs = 10

for epoch_idx in range(1, num_epochs + 1):
    print(f"Epoch {epoch_idx}/{num_epochs}")

    epoch_train_loss = train(model, train_loader, optimizer, criterion)
    print(f"Training Loss: {epoch_train_loss:.4f}")

    epoch_test_acc = evaluate(model, test_loader)
    print(f"Test Accuracy: {epoch_test_acc:.2f}%")