if 'google.colab' in str(get_ipython()):  # Google Colab specific setup
    !git clone https://github.com/tensorturtle/deep-sinusoidal-grating.git
    %cd deep-sinusoidal-grating
import torch
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DotDistortionsDataset Tutorial
by Jason Sohn
Selectively distorted dot patterns are a classic stimulus scheme [1] that has been used in cognitive psychology research for many decades. It has been instrumental in establishing core concepts such as 'prototypes' [2].
I have implemented the dot-pattern generation method proposed by Smith et al. (2005) [3] from scratch in Python for easy integration with PyTorch. I hope that cross-pollination efforts like this one will catalyze exciting research at the intersection of psychology and machine learning.
By the end of this tutorial, you will have a solid understanding of all the 'buttons and levers' of this dataset, and be able to train a basic neural network on it.
The original code was supposedly written in Turbo Pascal (a programming language of the ancient Babylonians or something...).
[1] Posner, Michael I.; Goldsmith, Ralph; Welton, Kenneth E., Jr. (1967). Perceived distance and the classification of distorted patterns. Journal of Experimental Psychology, 73(1), 28–38. doi:10.1037/h0024135
[2] Posner, Michael I.; Keele, Steven W. (1968). On the genesis of abstract ideas. Journal of Experimental Psychology, 77(3, Pt.1), 353–363. doi:10.1037/h0025953
[3] Smith, J. David; Redford, Joshua S.; Gent, Lauren C.; Washburn, David A. (2005). Visual search and the collapse of categorization. Journal of Experimental Psychology: General, 134(4), 443–460. doi:10.1037/0096-3445.134.4.443
For our first demonstration, we will create 16 examples of a single category.
from DotDistortionsDataset import DotDistortions
dataset = DotDistortions(
    length = 16,
    train_like = True,     # single shape in image, with no distractors
    category_seeds = [8],  # pick a favorite number
    num_categories = 1,
)
Since we're starting from scratch, we must call produce() on the dataset. This actually generates the dataset according to the parameters we set above and saves it, so that we can load it again later.
dataset.produce('temp/mini_dataset.pkl')
dataloader = iter(dataset) # iter() turns a dataset into an iterator, which allows us to fetch consecutive data points using next()
image, bboxes, labels = next(dataloader)
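Since we'll eventually feed these images to a PyTorch model, here is a minimal sketch of how a sample might be moved onto the device we selected earlier. It assumes image is a 2-D NumPy array of pixel intensities (consistent with passing it straight to plt.imshow() below); the shapes in the comments are illustrative, not part of the library's API.
import numpy as np

image_tensor = torch.from_numpy(np.asarray(image, dtype=np.float32))  # H x W
image_tensor = image_tensor.unsqueeze(0).to(device)                   # 1 x H x W, on 'cuda:0' or 'cpu'
print(image_tensor.shape, image_tensor.device)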
Let's see the first image in this dataset:
import matplotlib.pyplot as plt
plt.imshow(image, cmap='gray')
The bounding box marks the outermost points of the shape:
bboxes
(30, 39, 120, 111)
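To sanity-check those numbers, we can draw the box on top of the image. This is just a sketch; it assumes the tuple is ordered (x_min, y_min, x_max, y_max), which the library does not spell out here, so if the overlay looks transposed, swap the axes.
import matplotlib.patches as patches

x_min, y_min, x_max, y_max = bboxes  # assumed ordering

fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
ax.add_patch(patches.Rectangle(
    (x_min, y_min),  # lower-left corner of the box
    x_max - x_min,   # width
    y_max - y_min,   # height
    linewidth=1, edgecolor='red', facecolor='none',
))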
The label refers to the category. This dataset will always create the same shape for any given label number; technically speaking, the label (category) number seeds the random number generator.
labels
8
Let's see the whole dataset:
# just a little helper function to show a dataset's images in a grid
def show_bunch(
    dataset,
    rows=2,
    x_figsize=20,
    y_figsize=5,
):
    dataset.produce()
    dataloader = iter(dataset)
    fig = plt.figure(figsize=(x_figsize, y_figsize))
    for i in range(len(dataset)):
        image, bboxes, labels = next(dataloader)
        ax = fig.add_subplot(rows, len(dataset) // rows, i + 1)
        ax.set_title(str(labels))
        # hide the axis ticks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(image, cmap='gray')
dataset = DotDistortions(
    length = 16,
    train_like = True,     # single shape in image, with no distractors
    category_seeds = [8],  # pick a favorite number
    num_categories = 1,
)
show_bunch(dataset)
Notice how the above shapes are all quite similar to each other. It would be pretty easy to learn these shapes from the images.
What happens if we increase the distortion level?
dataset = DotDistortions(
    length = 8,
    train_like = True,         # single shape in image, with no distractors
    category_seeds = [8],      # pick a favorite number
    num_categories = 1,
    distortion_level = '7.7',  # MAXIMUM distortion!
)
show_bunch(dataset, rows=1, x_figsize=20, y_figsize=2.5)
With a high distortion level of '7.7', it becomes harder to see the commonality between the shapes.
Choose from: '1', '2', '3', '4', '5', '6', '7.7'.
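To compare the levels side by side, a quick loop over the allowed values works. This is just a sketch that reuses show_bunch from above; each iteration generates a fresh dataset for that distortion level.
for level in ['1', '2', '3', '4', '5', '6', '7.7']:
    dataset = DotDistortions(
        length = 8,
        train_like = True,
        category_seeds = [8],
        num_categories = 1,
        distortion_level = level,
    )
    show_bunch(dataset, rows=1, x_figsize=20, y_figsize=2.5)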
Just for fun, let's mix three different categories in one dataset.
dataset = DotDistortions(
    length = 16,
    train_like = True,
    category_seeds = [2**11, 314, 777],  # pick three favorite numbers
    num_categories = 3,                  # must match the number of seeds above
    distortion_level = '1',              # no distortion
)
show_bunch(dataset, rows=2, x_figsize=20, y_figsize=5)
Supplying the same numbers to category_seeds will always generate the same basic shapes; the order and shading type, however, are random each time.
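To see this determinism for yourself, here is a small sketch: produce two independent datasets with the same category_seeds and show them one after the other. The same three basic shapes should appear in both grids, even though the order and shading of the individual images differ between runs.
for run in range(2):
    repeat_dataset = DotDistortions(
        length = 8,
        train_like = True,
        category_seeds = [2**11, 314, 777],
        num_categories = 3,
        distortion_level = '1',
    )
    show_bunch(repeat_dataset, rows=1, x_figsize=20, y_figsize=2.5)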
In the reference literature, subjects were first trained to identify single-shape images (shown above). They were then shown a test image containing 7 mixed shapes, of which 0 or 1 belonged to the trained categories, and asked whether a trained shape was present in the test image.
For example, let's say that one of the shapes we were trained to identify was:
dataset = DotDistortions(
    length = 1,
    train_like = True,      # single shape in image, with no distractors
    category_seeds = [52],  # pick a favorite number
    num_categories = 1,
)
show_bunch(dataset, rows=1, x_figsize=5, y_figsize=5)
Do you see this shape in each of the test images below?
dataset = DotDistortions(
    length = 4,
    train_like = False,     # test mode
    category_seeds = [52],
    num_categories = 1,
    distortion_level = '1',               # no distortion; easy mode
    test_like_exists_probability = 1.0,   # every test image will certainly contain a shape of the trained category
    total_shapes = 7,                     # total number of shapes in each test image
)
show_bunch(dataset, rows=2, x_figsize=20, y_figsize=15)
The answer is YES!
Note: the category number of every random (non-trained) shape is -1.
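That convention makes it easy to turn labels into a binary "does this test image contain a trained shape?" target. The helper below is only a sketch; it assumes that in test mode labels is a sequence holding one category number per shape, with -1 marking the random shapes, as noted above.
def contains_trained_shape(labels):
    # any label other than -1 belongs to a trained category
    return any(label != -1 for label in labels)

dataset.produce()
dataloader = iter(dataset)
image, bboxes, labels = next(dataloader)
print(contains_trained_shape(labels))  # expected True, since test_like_exists_probability is 1.0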
Here are some more examples of test images, but now some do not contain the shape we were trained to identify.
dataset = DotDistortions(
    length = 4,
    train_like = False,     # test mode
    category_seeds = [52],
    num_categories = 1,
    distortion_level = '1',                # no distortion; easy mode
    test_like_exists_probability = 0.25,   # a quarter of the test images will contain a shape of the trained category
    total_shapes = 7,                      # total number of shapes in each test image
)
show_bunch(dataset, rows=2, x_figsize=20, y_figsize=15)
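As a rough sanity check on test_like_exists_probability, we can draw a larger sample and count how often a trained shape shows up; with a probability of 0.25, the fraction should hover around a quarter. This reuses the contains_trained_shape sketch from above, along with the same assumptions about the structure of labels.
check_dataset = DotDistortions(
    length = 200,
    train_like = False,
    category_seeds = [52],
    num_categories = 1,
    distortion_level = '1',
    test_like_exists_probability = 0.25,
    total_shapes = 7,
)
check_dataset.produce()

dataloader = iter(check_dataset)
positives = 0
for _ in range(len(check_dataset)):
    _, _, labels = next(dataloader)
    if contains_trained_shape(labels):
        positives += 1

print(positives / len(check_dataset))  # should come out near 0.25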