Image Classification with Localization

In image classification, we feed an input image to a convolutional neural network, which produces a feature vector (the output of a fully connected layer). The feature vector is then passed to a softmax layer to predict a class. Let's say we are building an image classifier for a self-driving car application with a set of four possible classes:

  • $c_{1}$ pedestrian
  • $c_{2}$ car
  • $c_{3}$ motorcycle
  • $c_{4}$ background

In this case, the softmax layer has 4 units (outputs), one per class, which amounts to a standard image classification pipeline.
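As a minimal sketch of that final step, the snippet below applies a softmax to four raw scores and picks the most probable class. The score values and the names `CLASSES`, `softmax` and `logits` are illustrative assumptions, not part of any particular network.

```python
import math

# Class order follows the list above: c_1 .. c_4
CLASSES = ["pedestrian", "car", "motorcycle", "background"]

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores from the last fully connected layer
logits = [0.5, 2.0, 0.1, -1.0]
probs = softmax(logits)

# The predicted class is the one with the highest probability
predicted = CLASSES[probs.index(max(probs))]
print(predicted, probs)
```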

What if we also have to localize the object in the image? For that, the network's output layer gets four additional numbers alongside the class outputs, $b_{x}$, $b_{y}$, $b_{h}$, $b_{w}$, which parameterize the bounding box around the object.

Bounding box coordinates are defined according to the following convention:

  • upper left corner of the image: (0,0)
  • lower right corner of the image: (1,1)
  • midpoint of the bounding box: ($b_{x}$, $b_{y}$)
  • width of the bounding box: $b_{w}$
  • height of the bounding box: $b_{h}$
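To make the convention concrete, here is a small sketch that converts a normalized box (midpoint, width, height) into pixel coordinates of its corners. The function name `box_to_pixel_corners` and the 640x480 image size are assumptions chosen for illustration.

```python
def box_to_pixel_corners(b_x, b_y, b_w, b_h, img_w, img_h):
    """Convert a normalized (midpoint, width, height) box to pixel corners.

    All b_* values are fractions of the image size, with (0, 0) at the
    upper left corner of the image and (1, 1) at the lower right corner.
    """
    x_min = (b_x - b_w / 2) * img_w
    y_min = (b_y - b_h / 2) * img_h
    x_max = (b_x + b_w / 2) * img_w
    y_max = (b_y + b_h / 2) * img_h
    return x_min, y_min, x_max, y_max

# Hypothetical box: centered horizontally, in the lower part of the image
corners = box_to_pixel_corners(0.5, 0.75, 0.5, 0.25, img_w=640, img_h=480)
print(corners)
```

Keeping the box in normalized coordinates makes the labels independent of image resolution; pixel values are only needed when drawing or cropping.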

In the training set, each example contains not only a class label but also four additional bounding box numbers ($b_{x}$, $b_{y}$, $b_{h}$, $b_{w}$). Our convolutional neural network now learns to predict both the class of the object and its bounding box.

The target label vector $y$ consists of 8 components:

  • $p_{c}$ - is there any object? (1 if yes, 0 if no)
  • $b_{x}$ - x coordinate of the midpoint of the bounding box
  • $b_{y}$ - y coordinate of the midpoint of the bounding box
  • $b_{w}$ - width of the bounding box
  • $b_{h}$ - height of the bounding box
  • $c_{1}$ pedestrian
  • $c_{2}$ car
  • $c_{3}$ motorcycle

If an image contains a car and $c_{2}$ represents the car class, then $p_{c}$ and $c_{2}$ will be 1, $c_{1}$ and $c_{3}$ will be 0, and $b_{x}$, $b_{y}$, $b_{w}$ and $b_{h}$ will hold the bounding box coordinates.

If the image has no object, then $p_{c}$ will be 0 and the rest of the outputs are marked ?, meaning "don't care": their values do not matter for that example.
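The two cases can be sketched as a small helper that builds the 8-component target vector. The helper name `make_target`, the `CLASS_INDEX` mapping, and the use of `None` to stand in for the "don't care" entries are assumptions made for this illustration.

```python
# Position of each class within the last three components (c_1, c_2, c_3)
CLASS_INDEX = {"pedestrian": 0, "car": 1, "motorcycle": 2}

def make_target(label=None, box=None):
    """Build the target vector [p_c, b_x, b_y, b_w, b_h, c_1, c_2, c_3].

    label: one of the keys in CLASS_INDEX, or None if there is no object.
    box:   (b_x, b_y, b_w, b_h) in normalized image coordinates.
    """
    if label is None:
        # No object: p_c = 0, everything else is "don't care" (None here)
        return [0.0] + [None] * 7
    b_x, b_y, b_w, b_h = box
    one_hot = [0.0, 0.0, 0.0]
    one_hot[CLASS_INDEX[label]] = 1.0
    return [1.0, b_x, b_y, b_w, b_h] + one_hot

# An image with a car (hypothetical box values)
y_car = make_target("car", (0.5, 0.75, 0.5, 0.25))
# An image with only background
y_background = make_target()
print(y_car)
print(y_background)
```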

A loss function for training the neural network for classification with localization could look like the following:

$$\mathcal{L}(\hat{y}, y) = (\hat{y}_{1} - y_{1})^2 + (\hat{y}_{2} - y_{2})^2 + \dots + (\hat{y}_{8} - y_{8})^2$$
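A minimal sketch of this squared-error loss is shown below. It assumes that when no object is present ($y_{1} = 0$) the "don't care" components are simply left out, so only the first term is counted; the function name `localization_loss` and the example values are hypothetical.

```python
def localization_loss(y_hat, y):
    """Squared-error loss over the 8-component output.

    y_hat: predicted vector [p_c, b_x, b_y, b_w, b_h, c_1, c_2, c_3]
    y:     target vector in the same layout; when y[0] == 0 the remaining
           entries are "don't care" and are ignored (assumption of this sketch).
    """
    if y[0] == 0:
        # No object: only the p_c term contributes
        return (y_hat[0] - y[0]) ** 2
    # Object present: sum the squared error over all 8 components
    return sum((p - t) ** 2 for p, t in zip(y_hat, y))

# Object present: target is a car, prediction is off only in p_c
y = [1.0, 0.5, 0.75, 0.5, 0.25, 0.0, 1.0, 0.0]
y_hat = [0.5, 0.5, 0.75, 0.5, 0.25, 0.0, 1.0, 0.0]
loss_object = localization_loss(y_hat, y)

# No object: the box and class predictions do not affect the loss
y_none = [0.0] + [None] * 7
loss_no_object = localization_loss(y_hat, y_none)
print(loss_object, loss_no_object)
```

In practice the three kinds of outputs are often trained with different loss terms, but squared error on every component keeps the example simple.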
