Seeing through the eyes of neural networks

with FlashTorch


Misa Ogura

Research Software Engineer @ BBC R&D

📡 Slide deck @ tinyurl.com/flashtorch-ldn-meetup

Hello 👋


  • From Tokyo, based in London
  • Cancer Cell Biologist
  • Research Software Engineer
  • Co-founder & CTO of Women Driven Development

Convolutional neural network (CNN)

Breakthrough in computer vision


Convolution

Image processing technique


Detects features such as edges and curves.

The features (convolution kernels) are pre-defined by hand, as in the sketch below.
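
For instance, a classic edge detector is just a convolution with a hand-written kernel. A minimal illustrative sketch (not from the talk):

import torch
import torch.nn.functional as F

# A hand-crafted Sobel kernel that responds to vertical edges,
# shaped (out_channels, in_channels, height, width)
sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]]).view(1, 1, 3, 3)

image = torch.rand(1, 1, 224, 224)           # stand-in grayscale image
edges = F.conv2d(image, sobel_x, padding=1)  # edge map, same spatial size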

CNN

Convolution + deep neural network


Useful features are learned during training, rather than pre-defined (see the sketch below).
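
In PyTorch terms, the only change is that the kernels become learnable parameters (again an illustrative sketch):

import torch.nn as nn

# The same 3x3 convolution, but now the 16 kernels are parameters:
# initialised randomly and shaped by training, not hand-crafted
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

print(conv.weight.shape)          # torch.Size([16, 3, 3, 3])
print(conv.weight.requires_grad)  # True: updated by backprop during training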

How can we explain the output of a CNN?

Feature visualization

"Looking through" the eyes of neural nets


How do CNNs perceive images?


Brilliant series of articles on Distill.

Saliency maps

What is most noticeable/important?


  • Saliency: a subjective quality in human visual perception

  • Makes certain regions stand out

Saliency maps in computer vision highlight salient regions.

Saliency maps for CNN

Gives some intuition about what the network pays attention to


  • First introduced in 2013

  • Calculates the relationship (i.e. gradients) between the output label & the input image pixels (see the sketch below)

  • Positive gradients == positive effects on the confidence of the label
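
The idea in code (a hypothetical sketch using plain autograd; the class index is an assumed ImageNet index):

import torch
from torchvision import models

model = models.alexnet(pretrained=True).eval()

# A stand-in input; in practice this is a normalized image batch
input_ = torch.rand(1, 3, 224, 224, requires_grad=True)
target_class = 24  # assumed ImageNet index for 'great grey owl'

score = model(input_)[0, target_class]
score.backward()  # gradients of the class score w.r.t. every input pixel

# Collapse colour channels: one saliency value per pixel
saliency = input_.grad.abs().max(dim=1).values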

Introducing FlashTorch

Open source feature visualization toolkit


  • Supports torchvision models & custom PyTorch models

  • Available to install via pip

      $ pip install flashtorch

FlashTorch demo 1

Visualizing saliency maps


Prepare an input 🐦


In [2]:
import matplotlib.pyplot as plt

from flashtorch.utils import load_image, apply_transforms, denormalize, format_for_plotting

image = load_image('../../examples/images/great_grey_owl.jpg')

# Resize, normalize and batch the image into a (1, 3, 224, 224) tensor
owl = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(owl)}, {owl.shape}')

plt.imshow(format_for_plotting(denormalize(owl)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Create a Backprop object


In [3]:
from torchvision import models

from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
  • Registers custom hook functions to model layers
  • Grabs gradients out of the computational graph (sketched below)
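
Under the hood, the mechanism looks roughly like this (a simplified sketch, assuming AlexNet; not FlashTorch's exact code):

import torch
from torchvision import models

model = models.alexnet(pretrained=True).eval()
gradients = None

def record_gradients(module, grad_in, grad_out):
    global gradients
    gradients = grad_in[0]  # gradient w.r.t. the layer's input

# Hook the first conv layer so backprop hands over the input gradients
model.features[0].register_backward_hook(record_gradients)

# After a forward pass and output[0, target_class].backward(),
# `gradients` holds the (1, 3, 224, 224) map behind the saliency plot.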

Visualize saliency maps


In [4]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

backprop.visualize(owl, target_class, guided=True)

Pixels around the head and eyes have the strongest positive effects.

What about other birds?

What makes a peacock a peacock?


In [5]:
peacock = apply_transforms(load_image('../../examples/images/peacock.jpg'))
target_class = imagenet['peacock']

backprop.visualize(peacock, target_class, guided=True)

... or a toucan?


In [6]:
toucan = apply_transforms(load_image('../../examples/images/toucan.jpg'))
target_class = imagenet['toucan']

backprop.visualize(toucan, target_class, guided=True)

Do you agree? 🤖

What about when it gets things wrong?
Is FlashTorch still useful?

FlashTorch demo 2

When neural nets get things wrong


Transfer learning

A form of knowledge transfer


A model developed for one task is reused as a starting point for another.

Building a flower classifier


  1. Get a model, pre-trained on the ImageNet dataset (1000 classes)
  2. Swap out the last layer (102 classes) --> Un-tuned (see the sketch below)
  3. Train with a flower dataset --> Tuned
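
Steps 1 and 2 in code (a sketch, assuming a DenseNet backbone; the talk doesn't specify the architecture):

import torch.nn as nn
from torchvision import models

# 1. An ImageNet-pretrained backbone with a 1000-class head
model = models.densenet161(pretrained=True)

# Freeze the pretrained feature extractor
for param in model.parameters():
    param.requires_grad = False

# 2. Swap in a fresh, randomly initialised 102-class head --> "un-tuned"
model.classifier = nn.Linear(model.classifier.in_features, 102)

# 3. Training the new head on the flower dataset gives the "tuned" model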

Un-tuned model: 0.1% accuracy on the test dataset 😅

Why is it so bad? Let's take a look.

In [7]:
foxglove_image = load_image('images/foxgloves.jpg')
foxglove = apply_transforms(foxglove_image)  # tensor used in the cells below

plt.imshow(foxglove_image)
plt.title('Foxgloves')
plt.axis('off');

Un-tuned model

Test accuracy: 0.1%


In [10]:
backprop = Backprop(untuned_model)

backprop.visualize(foxglove, class_index, guided=True)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:111: UserWarning: The predicted class index 98 does not equal the target class index 96. Calculating the gradient w.r.t. the predicted class.
  'the gradient w.r.t. the predicted class.'
The network doesn't really know what to look for.

After training, the tuned model achieved >98% test accuracy 🎉
But how? What is it seeing now? 🤔

Tuned model

Test accuracy: >98%


In [11]:
backprop = Backprop(tuned_model)

backprop.visualize(foxglove, class_index, guided=True)
The model has learnt to focus on the most distinguishing pattern.

Show, don't tell: visualization is a powerful tool.

Model explainability


With feature visualization, we're better equipped to:

  • Focus on the mechanisms of how & what neural nets learn

  • Diagnose what and why the network gets things wrong

  • Spot and correct biases in algorithms

A step forward towards understanding & trusting AI 😌

Thank you 🙏


📡 Slide deck @ tinyurl.com/flashtorch-ldn-meetup