May 16, 2020

How To Tag Any Image Using Deep Learning With Pytorch Lightning!

Tyler Folkman

Photo by Michael Rogers on Unsplash

You probably have photos, right?

You probably want those photos tagged automatically for you, right?

But you also don’t want to write a ton of code to do so.

Read on to learn how to use deep learning and Pytorch to tag any photo with less than 60 lines of code. The best part is, you’ll only have to change about 3 lines of code to get it to work for your own images!

Tagging Monkeys

An extremely common machine learning problem is to classify or tag an image. Image classification is when you have a predefined set of classes and want to assign each image to one of them.

Let’s say you work at a zoo and are always forgetting the names of all the monkey species. It would be great if you had a way to automatically classify various pictures of monkeys with the appropriate species.

Why monkeys, you ask? Because there is an available dataset on Kaggle. :) This dataset contains about 1,400 images of 10 different species of monkeys. Here is a picture of the white-headed capuchin:

[Image: white-headed capuchin]

And one of the patas monkey:

[Image: patas monkey]

Having data is key. For your own problem, make sure you have some images that are already tagged. My recommendation would be to get at least 50 tagged images per class.

Once you have your images, let’s get them organized correctly. You will need to create two folders: “training” and “validation”. The photos in the training folder will be used to train the deep learning model, and the validation photos will be used to make sure the model is tuned well.

Within each folder, make a folder for each tag you have. For our monkeys, we have 10 tags, which we will call n0-n9. Thus, our folder structure looks like this:

├── training
│   ├── n0
│   ├── n1
│   ├── n2
│   ├── n3
│   ├── n4
│   ├── n5
│   ├── n6
│   ├── n7
│   ├── n8
│   └── n9
└── validation
    ├── n0
    ├── n1
    ├── n2
    ├── n3
    ├── n4
    ├── n5
    ├── n6
    ├── n7
    ├── n8
    └── n9

Then, place the appropriate images within each folder. A good rule of thumb is to put about 70% of your tagged images in training, 20% in validation, and hold out the remaining 10% for testing.
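
If your tagged images currently live in one folder per tag, here is a minimal sketch of one way to create that split. The source_dir layout, the destination folder names, and the exact ratios are assumptions you should adapt to your own setup.

import os
import random
import shutil

def split_images(source_dir, dest_dir, splits=(0.7, 0.2, 0.1)):
    # source_dir is assumed to hold one sub-folder per tag (e.g. n0-n9)
    random.seed(42)  # make the split reproducible
    for tag in os.listdir(source_dir):
        files = os.listdir(os.path.join(source_dir, tag))
        random.shuffle(files)
        n_train = int(len(files) * splits[0])
        n_val = int(len(files) * splits[1])
        buckets = [('training', files[:n_train]),
                   ('validation', files[n_train:n_train + n_val]),
                   ('testing', files[n_train + n_val:])]
        for bucket, names in buckets:
            target = os.path.join(dest_dir, bucket, tag)
            os.makedirs(target, exist_ok=True)
            for name in names:
                shutil.copy(os.path.join(source_dir, tag, name), target)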

We will also maintain a mapping from n0-n9 to the actual species names, so we don’t forget:

Label, Latin Name            , Common Name
n0   , alouatta_palliata     , mantled_howler
n1   , erythrocebus_patas    , patas_monkey
n2   , cacajao_calvus        , bald_uakari
n3   , macaca_fuscata        , japanese_macaque
n4   , cebuella_pygmea       , pygmy_marmoset
n5   , cebus_capucinus       , white_headed_capuchin
n6   , mico_argentatus       , silvery_marmoset
n7   , saimiri_sciureus      , common_squirrel_monkey
n8   , aotus_nigriceps       , black_headed_night_monkey
n9   , trachypithecus_johnii , nilgiri_langur
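
In code, that mapping is just a small dictionary you can use later when printing predictions:

# Mapping from folder label to common name, taken from the table above
label_to_name = {
    'n0': 'mantled_howler',
    'n1': 'patas_monkey',
    'n2': 'bald_uakari',
    'n3': 'japanese_macaque',
    'n4': 'pygmy_marmoset',
    'n5': 'white_headed_capuchin',
    'n6': 'silvery_marmoset',
    'n7': 'common_squirrel_monkey',
    'n8': 'black_headed_night_monkey',
    'n9': 'nilgiri_langur',
}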

Build Your Model

ResNet-50

An extremely popular neural network architecture for tagging images is ResNet-50. It does a good job balancing accuracy and complexity. I won’t go into depth on this deep learning model, but you can learn more here. For our purposes, just know it’s a really good model for image classification, and you should be able to train it in a reasonable time if you have access to a GPU. If you don’t, take a look at Google Colab to get access to free GPU resources.
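
If you’re not sure whether Pytorch can see a GPU on your machine, a one-line check:

import torch

print(torch.cuda.is_available())  # True means a CUDA GPU is ready to use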

Fine Tuning

One of the tricks we will use when training our model is fine-tuning, which will hopefully let us learn to tag accurately with only a few examples per class.

Fine-tuning starts our model with weights already trained on another dataset. We then further tune the weights using our own data. A very common dataset to use as the starting point for fine-tuning is the ImageNet dataset. This dataset originally contained about 1 million images and 1,000 classes or tags. The breadth of image tags tends to make it a good dataset for fine-tuning.
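
In plain Pytorch, the core of the idea looks like this (the LightningModule we build below does the same thing in its __init__). Freezing the pre-trained layers is an optional variation, not something we do in this post:

import torch.nn as nn
from torchvision import models

# Load a ResNet-50 whose weights were already trained on ImageNet
backbone = models.resnet50(pretrained=True)

# Optional variation: freeze the pre-trained layers so only the new
# head is trained (we leave everything trainable in this post)
# for param in backbone.parameters():
#     param.requires_grad = False

# Swap the 1,000-class ImageNet head for one matching our 10 tags
backbone.fc = nn.Linear(backbone.fc.in_features, 10)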

Pytorch Lightning

Besides fine-tuning, there are other tricks we can apply to help our deep learning model train well on our data. For example, we can use a learning rate finder to pick the best learning rate.

Implementing all these best practices and keeping track of all the training steps can lead to a lot of code. To avoid all of this boilerplate, we are going to use Pytorch Lightning. I love this library. I find it really helps me to organize my Pytorch code well and avoid dumb mistakes such as forgetting to zero out my gradients.

We will use Pytorch Lightning by writing a class that implements LightningModule. Here is the bulk of our code; we’ll walk you through it below:

from pytorch_lightning import LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torchvision import datasets, models, transforms


class ImagenetTransferLearning(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # init a pretrained resnet
        self.classifier = models.resnet50(pretrained=True)
        # swap the final layer for one with our number of target classes
        num_ftrs = self.classifier.fc.in_features
        self.classifier.fc = nn.Linear(num_ftrs, self.hparams.num_target_classes)

    def forward(self, x):
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def train_dataloader(self):
        # resize, augment, and normalize the same way ResNet-50 was pre-trained
        train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        dataset = datasets.ImageFolder(self.hparams.train_dir, train_transforms)
        loader = data.DataLoader(dataset, batch_size=self.hparams.batch_size,
                                 num_workers=4, shuffle=True)
        return loader

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'val_loss': loss}
        return {'val_loss': loss, 'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        # average the per-batch validation losses for the epoch
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def val_dataloader(self):
        # no augmentation for validation: just resize, crop, and normalize
        val_transforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        dataset = datasets.ImageFolder(self.hparams.val_dir, val_transforms)
        loader = data.DataLoader(dataset, batch_size=self.hparams.batch_size,
                                 num_workers=4)
        return loader

The first function we define is __init__(). This is the function we use to initialize our model. We start with the pre-trained ResNet-50 from Pytorch and modify it slightly so that it predicts the appropriate number of classes. The number of classes or tags you want to predict is passed as part of hparams as num_target_classes.

Next, we have the forward() function. This one is simple: we just pass the input through our network.

Then we have the training_step() function. This function takes two inputs: the batch and the batch index. Within it, all we need to define is what should happen during each training step. For this model, it is very simple: we pass the data through self(), which is our neural network, and then calculate the cross-entropy as our loss. It is standard to return a dictionary with the calculated loss as well as a log entry for Tensorboard. One of the great benefits of Pytorch Lightning is that if you do this, you get Tensorboard logging basically for free, which is super nice!

The configure_optimizers() function is used to define your optimizer. We will use the Adam optimizer and pass the learning rate via our hparams.

Lastly, for training, you have the train_dataloader() function. This is the function that takes care of loading your training data and passing it to your training step. We define our transforms to resize and normalize the images in the same way our ResNet-50 was pre-trained, and we apply some data augmentation with RandomResizedCrop() and RandomHorizontalFlip(). We then load the data with Pytorch’s ImageFolder() function, which reads images from a folder as long as the folder follows the structure we defined previously. The dataset is passed to a DataLoader(), which is what Pytorch uses to actually load the data; within it, we can define items such as the batch_size, which we pass as a hyper-parameter via hparams.

Since we also have validation data, we define the corresponding functions for it: validation_step() and val_dataloader(). These are very similar to their training counterparts. The main differences are that we no longer apply data augmentation and our step returns val_loss.

The validation section also has an additional function: validation_epoch_end(). This defines what should be done with the validation results at the end of an epoch. We simply return the average validation loss. You can do the same for the training steps if you wish, as shown below.
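
If you want that end-of-epoch averaging for training too, here is a sketch of the analogous hook to add to the class; the exact return format can vary between Lightning versions, so treat this as an illustration rather than the one true API:

def training_epoch_end(self, outputs):
    # Average the per-step training losses at the end of each epoch
    avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    return {'log': {'avg_train_loss': avg_loss}}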

Training

Now that we have done the heavy lifting of defining all the necessary steps, we can sit back and let Pytorch Lightning do its magic. First, let’s define our hyper-parameters (Pytorch Lightning expects them as an argparse Namespace):

from argparse import Namespace

hparams = Namespace(train_dir=<PATH TO YOUR TRAINING DIRECTORY>,
                    val_dir=<PATH TO YOUR VALIDATION DIRECTORY>,
                    num_target_classes=<NUMBER OF TAGS/CLASSES>,
                    lr=0.001,
                    batch_size=8)

I set the batch size pretty small to work with pretty much any GPU.

Next, we initialize our model and train!

from pytorch_lightning import Trainer

model = ImagenetTransferLearning(hparams)
trainer = Trainer(gpus=1,
                  early_stop_callback=True,
                  auto_lr_find=True,
                  max_epochs=50)
trainer.fit(model)

The Trainer() is where the real magic happens. First, we tell it how many GPUs to train on; then we let it know to stop training early if the val_loss stops improving. One of the coolest options is auto_lr_find. This tells the trainer to use an algorithm to find the best learning rate for our model and data and then use that rate instead of the one we specified. Note: this only works if you pass hparams to your model and there is a lr value within your hparams. Lastly, to avoid running for too long, we set max_epochs to 50, and trainer.fit(model) kicks off the actual training.

If you’ve done much deep learning, you can appreciate the cleanliness of our Trainer(). We didn’t have to write a single loop over our data; it’s all taken care of for us. If we moved our code over to a machine with 8 GPUs, all we would have to do is change gpus to 8. That’s it. Pytorch Lightning also supports TPUs; you just turn on the option. At some point, you should definitely check out the docs on all the great options the Trainer() provides.

The Results

So how well did our model do at tagging monkeys? Pytorch Lightning automatically checkpoints the model with the best validation results, which, for me, happened at epoch 26. I loaded up that model with this code:

model = ImagenetTransferLearning.load_from_checkpoint(<PATH TO MODEL>)

And with this code, made predictions on all my validation data:

from tqdm import tqdm

model.eval()
val_outs = []
truth_outs = []
with torch.no_grad():  # no gradients needed at inference time
    for val_batch in tqdm(model.val_dataloader()):
        x, y = val_batch
        truth_outs.extend(y.numpy().tolist())
        val_out = model(x)
        val_outs.extend(val_out.detach().numpy().argmax(1).tolist())
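
With the predictions and ground truth in hand, producing a report takes a single scikit-learn call:

from sklearn.metrics import classification_report

# Compare the ground-truth labels against the model's predictions
print(classification_report(truth_outs, val_outs))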

Here is the resulting classification report:

              precision    recall  f1-score   support

           0       0.89      0.92      0.91        26
           1       0.93      0.89      0.91        28
           2       1.00      0.93      0.96        27
           3       0.97      0.93      0.95        30
           4       1.00      0.88      0.94        26
           5       1.00      1.00      1.00        28
           6       1.00      1.00      1.00        26
           7       1.00      0.96      0.98        28
           8       0.93      1.00      0.96        27
           9       0.84      1.00      0.91        26

   micro avg       0.95      0.95      0.95       272
   macro avg       0.96      0.95      0.95       272
weighted avg       0.96      0.95      0.95       272

Not bad! I got an average f1-score of 0.95, and my lowest per-class f1-score was 0.91.

These are validation results, though, so they are very likely to be optimistic. To get a better sense of how well our model does, we need to predict on images that appear in neither the training set nor the validation set.

I didn’t take the time to create an entire test set, but I did grab 2 random images of monkeys from Google. In fact, they are the 2 images at the top of this post. And our model was able to predict them both correctly!
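
If you want to run the same kind of spot check on your own downloaded images, here is a minimal sketch; the file name is a placeholder:

from PIL import Image
import torch
from torchvision import transforms

# Same preprocessing as the validation dataloader
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

img = Image.open('monkey.jpg').convert('RGB')  # placeholder file name
x = preprocess(img).unsqueeze(0)  # add a batch dimension

model.eval()
with torch.no_grad():
    pred = model(x).argmax(1).item()
print(f'Predicted tag: n{pred}')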

Also, here are the Tensorboard graphs for the training:

[Image: Tensorboard training graphs]

It’s safe to say that we are now, thanks to our deep learning model, experts on monkey species. :)

Go Do It Yourself!

The beautiful part is you can now easily go and classify any images you want. All you have to do is tag some of your own images, organize them appropriately (as discussed above), and change 3 lines of code.

hparams = Namespace(train_dir=<PATH TO YOUR TRAINING DIRECTORY>,
                    val_dir=<PATH TO YOUR VALIDATION DIRECTORY>,
                    num_target_classes=<NUMBER OF TAGS/CLASSES>,
                    lr=0.001,
                    batch_size=8)

The only 3 lines you need to update are the values for train_dir, val_dir, and num_target_classes. That’s it! So — go do it for yourself and let me know what cool things you classify!