Summary
In this notebook, we implement image captioning with a vision transformer. Image captioning associates a textual description with an image, converting an image into a written narrative and thereby bridging the domains of vision (image) and language (text). We show how a pretrained Vision Transformer (ViT) can perform this task, using the PyTorch backend as our primary technology. The goal is to demonstrate fine-tuning a ViT to generate image captions without retraining it from scratch.
Python functions and data files needed to run this notebook are available via this link.
from transformers import ViTModel
from transformers import VisionEncoderDecoderModel, GPT2TokenizerFast, AutoFeatureExtractor, \
AutoTokenizer, TrainingArguments, Trainer
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import Dataset
import torch
# torchvision has several functions for image processing
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor, Resize
import requests
from io import BytesIO
The vision transformer was first introduced in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. It was created to perform tasks similar to those of an NLP transformer, and the architecture is similar:
Instead of tokens for words, fixed-size patches of the image are used; each patch is a small subset of the image (see the quick check below).
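As a quick sanity check (a minimal sketch, assuming the 224x224 input size and 16x16 patch size named in the checkpoint we load below), we can compute how many patches the model sees per image:
image_size, patch_size = 224, 16
patches_per_side = image_size // patch_size   # 14 patches along each side
num_patches = patches_per_side ** 2           # 14 * 14 = 196 patches in total
print(f"Each image is split into {num_patches} patches of {patch_size}x{patch_size} pixels")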
The model can be pre-trained on many types of datasets. The model we are using today has been pre-trained on the public ImageNet-21k dataset:
# Load up a pretrained Google vit model on HuggingFace
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
vit_model
The ViTModel is pre-trained on ImageNet. Its ViTEncoder is similar to the encoder of an LLM: it uses an attention mechanism with query, key, and value projections. Unlike an LLM, it starts with a patch_embeddings layer that embeds the image patches. The ViTModel stacks 12 encoder layers and ends with a pooler layer.
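We can confirm these numbers directly from the model configuration (a small sketch; the attributes below are standard ViT configuration fields):
# Inspect the ViT configuration to confirm the architecture described above
print("hidden layers:", vit_model.config.num_hidden_layers)  # 12 encoder layers
print("patch size   :", vit_model.config.patch_size)         # 16x16 patches
print("image size   :", vit_model.config.image_size)         # 224x224 input
print("hidden size  :", vit_model.config.hidden_size)        # 768-dimensional embeddings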
We need a feature extractor to convert images into tensors. This can be done with the AutoFeatureExtractor class:
# Load feature extractor
feature_ext = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
We need to pay attention to how images are preprocessed so that they match how the model was pretrained:
feature_ext
The feature extractor plays the role of a tokenizer in NLP. A tokenizer takes raw strings and turns them into tokens; the feature extractor takes a raw image and applies pre-processing to turn it into a tensor. Among its parameters are do_normalize, which normalizes pixel values to a given mean and standard deviation, and the resizing step, which converts the image to a size of 224x224.
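We can inspect those preprocessing settings directly on the extractor (a small sketch; these attributes are exposed by the extractor loaded above):
# Inspect the preprocessing settings the extractor will apply
print("do_normalize:", feature_ext.do_normalize)  # whether pixel values are normalized
print("image_mean  :", feature_ext.image_mean)    # per-channel mean used for normalization
print("image_std   :", feature_ext.image_std)     # per-channel std used for normalization
print("size        :", feature_ext.size)          # target image size (224x224)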
To load up an image, we can use Pillow's image object.
from PIL import Image
img = Image.open('./image/1.jpg')
display(img)
If we use feature extraction:
import matplotlib.pyplot as plt
plt.imshow(feature_ext(img).pixel_values[0].transpose(1, 2, 0))
feature_ext(img).pixel_values[0].shape
Although the transformed image looks distorted to us, the preprocessing removes a lot of noise that is irrelevant to the downstream task.
# Many weights are initialized randomly, namely the cross-attention weights
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
'google/vit-base-patch16-224-in21k', # :patch:16, image size: 224*224, pre-trained on: image net in21k
'distilgpt2') # decoder is `distilgpt2`
The combined vision-encoder-decoder model is like T5 in that it has both an encoder and a decoder.
type(model.encoder)
type(model.decoder)
total_params = 0
for param in model.parameters():
total_params += param.numel()
print(f"The model has a combined {total_params:,} parameters")
# Instantiate a tokenizer
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
We fine-tune on the Flickr8k dataset, described by its authors as "a new benchmark collection for sentence-based image description and search, consisting of 8,000 images that are each paired with five different captions which provide clear descriptions of the salient entities and events. … The images were chosen from six different Flickr groups, and tend not to contain any well-known people or locations, but were manually selected to depict a variety of scenes and situations."
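The loader below assumes each line of Flickr8k.token.txt holds an image file name with a #<caption_number> suffix, then a tab, then the caption text. As a quick peek at the raw format (a small sketch, assuming the file path used later in this notebook):
# Print the first few raw lines of the Flickr8k caption file
with open('./image/Flickr8k.token.txt') as f:
    for _ in range(3):
        print(f.readline().rstrip("\n"))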
def captions_load(image_caption,image_path, min_caption=10, max_caption=50, num_image=10):
'''
Obtain image captions and image names.
'''
with open(image_caption) as caption_file:
captions = caption_file.readlines()
map_captions = {}
map_captions_test = {}
text_data = []
text_data_test = []
if num_image>=len(captions):
num_image = len(captions)
# Loading up images from data set for training set
for line in captions[:num_image]:
line = line.rstrip("\n")
# Separate image name and captions using a tab
img_name, caption = line.split("\t")
# Five different captions are assigned to each image
# Each image name has a suffix `#(caption_number)`
img_name = img_name.split("#")[0]
img_name = os.path.join(image_path, img_name.strip())
if img_name.endswith("jpg"):
caption = caption.replace(' .', '').strip()
tokens = caption.strip().split()
if len(caption) < min_caption or len(caption) > max_caption:
continue
text_data.append(caption)
if img_name in map_captions:
map_captions[img_name].append(caption)
else:
map_captions[img_name] = [caption]
# Loading up images from data set for test set
for line in captions[num_image:]:
line = line.rstrip("\n")
# Separate image name and captions using a tab
img_name, caption = line.split("\t")
# Five different captions are assigned to each image
# Each image name has a suffix `#(caption_number)`
img_name = img_name.split("#")[0]
img_name = os.path.join(image_path, img_name.strip())
if img_name.endswith("jpg"):
caption = caption.replace(' .', '').strip()
tokens = caption.strip().split()
if len(caption) < min_caption or len(caption) > max_caption:
continue
text_data_test.append(caption)
if img_name in map_captions_test:
map_captions_test[img_name].append(caption)
else:
map_captions_test[img_name] = [caption]
return map_captions, text_data, map_captions_test, text_data_test
# Load the dataset
image_path = './image/Flickr8k_Dataset'
image_caption = './image/Flickr8k.token.txt'
map_captions, caption_only, map_captions_test, text_data_test = captions_load(image_caption,
image_path, num_image=7500)
list(map_captions.items())[:2]
list(map_captions_test.items())[:2]
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
normalize = Normalize(
mean=feature_extractor.image_mean,
std=feature_extractor.image_std
)
process_image = Compose(
[
RandomResizedCrop(list(feature_extractor.size.values())), # Data augmentation. Take a random resized crop of our image
ToTensor(), # Convert to pytorch tensor
normalize # normalize pixel values to look like images during pre-training
]
)
rows = []
# It is ok to have multiple captions per image because of data augmentation
for path, captions in map_captions.items():
for caption in captions:
rows.append({'path': path, 'caption': caption})
image_df = pd.DataFrame(rows)
image_dataset = Dataset.from_pandas(image_df)
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
def image_preprocess(examples):
# ViT expects pixel_values instead of input_ids
examples['pixel_values'] = [process_image(Image.open(path)) for path in examples['path']]
# We are padding tokens here instead of using a datacollator
tokenized = gpt2_tokenizer(
examples['caption'], padding='max_length', max_length=10, truncation=True
)['input_ids']
# the output captions
examples['labels'] = [[l if l != gpt2_tokenizer.pad_token_id else -100 for l in t] for t in tokenized]
# delete unused keys
del examples['path']
del examples['caption']
return examples
image_dataset = image_dataset.map(image_preprocess, batched=True)
# Train test split
image_dataset = image_dataset.train_test_split(test_size=0.1)
image_dataset
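Before training, it is worth sanity-checking one preprocessed example (a minimal sketch; after dataset.map the tensors are stored as nested lists, so we convert back with numpy):
# Inspect one preprocessed training example
sample = image_dataset['train'][0]
print(np.array(sample['pixel_values']).shape)  # expected: (3, 224, 224)
print(sample['labels'])                        # caption token ids, with padding replaced by -100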
# We set a pad token and a start token in our combined model to be the same as gpt2
model.config.pad_token = gpt2_tokenizer.pad_token
model.config.pad_token_id = gpt2_tokenizer.pad_token_id
model.config.decoder_start_token = gpt2_tokenizer.bos_token
model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
Since the ViT model is very large, fine-tuning all parameters would take a long time, so we can freeze some parameters to speed up the fine-tuning process.
# Get the number of layers
config = model.config
num_layers = config.encoder.num_hidden_layers
print("Number of hidden layers in the VisionEncoderDecoderConfig model:", num_layers)
## To speed up training, we could freeze the earliest encoder layers (everything before encoder.layer.3),
## leaving the later layers trainable. Uncomment to enable:
# for name, param in model.encoder.named_parameters():
#     if 'encoder.layer.3' in name:  # stop once we reach the 4th encoder layer
#         break
#     param.requires_grad = False    # disable training for this parameter
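If we do freeze part of the encoder, we can verify how many parameters remain trainable (a small sketch using the same requires_grad flag as the commented block above):
# Count trainable vs. frozen parameters
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"Trainable parameters: {trainable:,}")
print(f"Frozen parameters:    {frozen:,}")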
epochs = 3
batch_size = 5
from transformers import set_seed
set_seed(42)
training_args = TrainingArguments(
output_dir='./caption_image', # The output directory
overwrite_output_dir=True, # overwrite the content of the output directory
num_train_epochs=epochs, # number of training epochs
per_device_train_batch_size=batch_size, # batch size for training
per_device_eval_batch_size=batch_size, # batch size for evaluation
load_best_model_at_end=True,
log_level='info',
logging_steps=50,
evaluation_strategy='epoch',
save_strategy='epoch',
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=image_dataset['train'],
eval_dataset=image_dataset['test'],
)
import os
os.environ['WANDB_DISABLED'] = 'true'
trainer.evaluate()
trainer.train()
# the loss decline is starting to slow down. This is a good indication that we may want to try training on more data
trainer.save_model()
# loading model and config from pretrained folder
finetuned_model = VisionEncoderDecoderModel.from_pretrained('./caption_image')
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
normalize = Normalize(
mean=feature_extractor.image_mean,
std=feature_extractor.image_std
)
# Create a new composition that doesn't randomly crop images at inference, to make it easier for the model
inferenceprocess_image = Compose(
[
Resize(list(feature_extractor.size.values())), # deterministic resize instead of a random crop
ToTensor(),
normalize
]
)
# a helper function to caption images from the web or a file path
def caption_image(m,path,num_beams=3,max_length=15,top_k=10, num_return_sequences=5):
if 'http' in path:
response = requests.get(path)
img = Image.open(BytesIO(response.content))
image_matrix = inferenceprocess_image(img).unsqueeze(0)
else:
img = Image.open(path)
image_matrix = inferenceprocess_image(img).unsqueeze(0)
generated = m.generate(
image_matrix,
num_beams=num_beams,
max_length=max_length,
early_stopping=True,
do_sample=True,
top_k=top_k,
num_return_sequences=num_return_sequences,
)
caption_options = [gpt2_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
display(img)
return caption_options, generated, image_matrix
captions, generated, image_matrix = caption_image( # Out of sample photo
finetuned_model, list(map_captions_test.items())[0][0]
)
captions
captions, generated, image_matrix = caption_image( # Another one
finetuned_model, list(map_captions_test.items())[1][0]
)
captions
captions, generated, image_matrix = caption_image( # from our Flickr dataset
finetuned_model,
list(map_captions_test.items())[2][0]
)
captions
url = "https://raw.githubusercontent.com/MehdiRezvandehy/Machine-Learning-Course-University-of-Calgary/master/Images/2308978137_bfe776d541.jpg"
captions, generated, image_matrix = caption_image( # Out of sample photo
finetuned_model, url
)
captions
# loading model and config from pretrained folder
finetuned_model = VisionEncoderDecoderModel.from_pretrained('./caption_image')
non_finetuned = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('google/vit-base-patch16-224-in21k',
'distilgpt2')
captions, generated, image_matrix = caption_image( # Out of sample photo
non_finetuned, list(map_captions_test.items())[0][0]
)
captions
captions, generated, image_matrix = caption_image( # Another one
non_finetuned, list(map_captions_test.items())[1][0]
)
captions
captions, generated, image_matrix = caption_image( # from our Flickr dataset
non_finetuned,
list(map_captions_test.items())[2][0]
)
captions
url = "https://raw.githubusercontent.com/MehdiRezvandehy/Machine-Learning-Course-University-of-Calgary/master/Images/2308978137_bfe776d541.jpg"
captions, generated, image_matrix = caption_image( # Out of sample photo
non_finetuned, url
)
captions
list(map_captions_test.items())[3][0]