CLIP模型代码

发布时间 2023-10-01 18:13:52作者: asdasfagsag

近期看到了一篇用CLIP在我这个方向应用的文章,所以玩了一下CLIP,感觉效果还是很好的。

 

首先,github上的zero-shot代码

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

这里稍微介绍一下,模型的model.encode_xxx方法是用来计算特征的,这个和前向传播没什么差别,唯一不同的是需要多一些处理操作,上面的代码主要做的事情就是预测图片属于100类中的哪一类,找出了top-5的结果。

代码1:

with torch.no_grad():
        logits_per_image, logits_per_text = model(image_input, text_inputs)
     prob = logits_per_image.softmax(-1)

代码2:

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
prob = (100.0 * image_features @ text_features.T).softmax(dim=-1)

代码1和代码2做的是一样的事情,都可以得到最后的预测,而且prob最后都是一样的,可以尝试一些。

 

第二个就是

我用cifar100测试集测试了一下VIT-B/32这个CLIP模型的zero-shot性能,最后是得到了61%的准确度,当然CLIP论文的作者在论文末尾也说了,未必对于所有目前流行的数据集都是完全zero-shot,不过它这个性能其实还是很不错的,虽然用0.4billion图片训练有点欺负人的意思。

import os
import clip
import torch
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"use device : {device}")
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset and the train=False that is mean we will download or load the test set
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
accuracy = 0
for item in tqdm(cifar100):
    image, class_id = item
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits_per_image, logits_per_text = model(image_input, text_inputs)
    temp_ans = logits_per_image.argmax().item()
    if temp_ans == class_id:
        accuracy += 1

accuracy/=10000
print(accuracy)