Multi-Graph Transformer for Free-Hand Sketch Recognition

Peng Xu¹  Chaitanya K. Joshi¹  Xavier Bresson¹
Abstract

Learning meaningful representations of free-hand sketches remains a challenging task given the signal sparsity and the high-level abstraction of sketches. Existing techniques have focused on exploiting either the static nature of sketches with Convolutional Neural Networks (CNNs) or the temporal sequential property with Recurrent Neural Networks (RNNs). In this work, we propose a new representation of sketches as multiple sparsely connected graphs. We design a novel Graph Neural Network (GNN), the Multi-Graph Transformer (MGT), for learning representations of sketches from multiple graphs which simultaneously capture global and local geometric stroke structures, as well as temporal information. We report extensive numerical experiments on a sketch recognition task to demonstrate the performance of the proposed approach. Particularly, MGT applied on 414k sketches from Google QuickDraw: (i) achieves a small recognition gap to the CNN-based performance upper bound (72.80% vs. 74.22%), and (ii) outperforms all RNN-based models by a significant margin. To the best of our knowledge, this is the first work proposing to represent sketches as graphs and apply GNNs for sketch recognition. Code and trained models are available at https://github.com/PengBoXiangShang/multigraph_transformer.
1. Introduction

Freehand sketches are drawings made without the use of any instruments. Sketches are different from traditional images: they are formed of temporal sequences of strokes (Ha & Eck, 2018; Xu et al., 2018), while images are static collections of pixels with dense color and texture patterns. Sketches capture high-level abstraction of visual objects

¹ School of Computer Science and Engineering, Nanyang Technological University, Singapore. Correspondence to: Xavier Bresson <xbresson@ntu.edu.sg>.

Figure 1. Sketches can be seen as sets of curves and strokes, which are discretized by graphs: (a) original sketch, (b) 1-hop connected, (c) 2-hop connected.

Figure 2. In sketch-based human-computer interaction scenarios, it is time-consuming to render and transfer pictures of sketches. Solely transferring stroke coordinates leads to real-time applications.
with very sparse information compared to regular images, which makes the modelling of sketches unique and challenging.
The modern prevalence of touchscreen devices has led to a flourishing of sketch-related applications in recent years, including sketch recognition (Liu et al., 2019; Sarvadevabhatla et al., 2016), sketch scene understanding (Ye et al., 2016), sketch hashing (Xu et al., 2018), sketch-based image retrieval (Sangkloy et al., 2016; Liu et al., 2017; Shen et al., 2018; Collomosse et al., 2019; Dutta & Akata, 2019; Dey et al., 2019), and sketch-related generative models (Ha & Eck, 2018; Chen & Hays, 2018; Lu et al., 2018; Liu et al., 2019).
If we assume sketches to be 2D static images, CNNs can be directly applied to sketches, such as "Sketch-a-Net" (Yu et al., 2015). If we now suppose that sketches are ordered sequences of point coordinates, then RNNs can be used to recursively capture the temporal information, e.g., "SketchRNN" (Ha & Eck, 2018).
In this work, we introduce a new representation of sketches
with graphs. We assume that sketches are sets of curves and strokes, which are discretized by a set of points representing the graph nodes. This view offers high flexibility to encode different sketch geometric properties as we can decide different connectivity structures between the node points. We use two types of graphs to represent sketches: intra-stroke graphs and extra-stroke graphs. The first graphs capture the local geometry of strokes, independently of each other, with for example 1-hop or 2-hop connected graphs, see Figure 1. The second graphs encode the global geometry and temporal information of strokes. Another advantage of using graphs is the freedom to choose the node features. For sketches, spatial, temporal and semantic information is available with the stroke point coordinates, the ordering of points, and the pen state information, respectively. In summary, representing sketches with graphs offers a universal representation that can make use of global and local spatial sketch structures, as well as temporal and semantic information.
To exploit these graph structures, we propose a new Transformer (Vaswani et al., 2017) architecture that can use multiple sparsely connected graphs. It is worth reporting that a direct application of the original Transformer model on the input spatio-temporal features provides poor results. We argue that the issue comes from the graph structure in the original Transformer, which is a fully connected graph. Although fully-connected word graphs work impressively for Natural Language Processing, where the underlying word representations themselves contain rich information, such dense graph structures provide poor innate priors/inductive bias (Battaglia et al., 2018) for 2D sketch tasks. Transformers require sketch-specific design coming from geometric structures. This led us to naturally extend Transformers to multiple arbitrary graph structures. Moreover, graphs provide more robustness to handle noisy and style-changing sketches as they focus on the geometry of strokes and not on the specific distribution of points.
Another advantage of using domain-specific graphs is to leverage the sparsity property of discretized sketches. Observe that intra-stroke and extra-stroke graphs are highly sparse adjacency matrices. In practical sketch-based human-computer interaction scenarios, it is time-consuming to directly transfer the original sketch picture from user touchscreen devices to the back-end servers. To ensure real-time applications, transferring the stroke coordinates as a character string would be more beneficial, see Figure 2.

Our main contributions can be summarised as follows:
(i) We propose to model sketches as sparsely connected graphs, which are flexible to encode local and global geometric sketch structures. To the best of our knowledge, it is the first time that graphs are proposed for representing sketches.
(ii) We introduce a novel Transformer architecture that can handle multiple arbitrary graphs. Using intra-stroke and extra-stroke graphs, the proposed Multi-Graph Transformer (MGT) learns both local and global patterns along sub-components of sketches.
(iii) This Multi-Graph Transformer model is agnostic to graph domains, and can be used beyond sketch applications.
(iv) Numerical experiments demonstrate the performance of our model. MGT significantly outperforms RNN-based models, and achieves a small recognition gap to CNN-based architectures. This is promising for real-time sketch-based human-computer interaction systems. Note that for sketch recognition, CNNs are the performance upper bound of coordinate-based models that involve truncating coordinate sequences, e.g., RNN or Transformer based architectures.
2. Related Work

Neural Network Architectures for Sketches  CNNs are a common choice for feature extraction from sketches. "Sketch-a-Net" (Yu et al., 2015) was the first CNN-based model having a sketch-specific architecture. It was directly inspired from AlexNet (Krizhevsky et al., 2012) with larger first-layer filters, no layer normalization, larger pooling sizes, and high dropout. Song et al. (2017) further improved Sketch-a-Net by adding spatial-semantic attention layers. "SketchRNN" (Ha & Eck, 2018) was a seminal work to model temporal stroke sequences with RNNs. A CNN-RNN hybrid architecture for sketches was proposed in (Sarvadevabhatla et al., 2016).

In this work, we propose a novel Graph Neural Network architecture for learning sketch representations from multiple sparse graphs, combining both stroke geometry and temporal order.
Graph Neural Networks  Graph Neural Networks (GNNs) (Bruna et al., 2014; Defferrard et al., 2016; Sukhbaatar et al., 2016; Kipf & Welling, 2017; Hamilton et al., 2017; Monti et al., 2017) aim to generalize neural networks to non-Euclidean domains such as graphs and manifolds. GNNs iteratively build representations of graphs through recursive neighborhood aggregation (or message passing), where each graph node gathers features from its neighbors to represent local graph structure.
Transformers  The Transformer architecture (Vaswani et al., 2017), originally proposed as a powerful and scalable alternative to RNNs, has been widely adopted in the Natural Language Processing community for tasks such as machine translation (Edunov et al., 2018; Wang et al., 2019), language modelling (Radford et al., 2018; Dai et al., 2019), and question-answering (Devlin et al., 2019; Yang et al., 2019).

Transformers for NLP can be regarded as GNNs which use self-attention (Bahdanau et al., 2014; Veličković et al., 2018) for neighborhood aggregation on fully-connected word graphs (Ye et al., 2019). However, GNNs and Transformers perform poorly when sketches are modelled as fully connected graphs. This work advocates for the injection of inductive bias into Transformers through domain-specific graph structures.
3. Method
3.1. Notation
We assume that the training dataset $\mathcal{D}$ consists of $N$ labeled sketches: $\mathcal{D} = \{(X_n, z_n)\}_{n=1}^{N}$. Each sketch $X_n$ has a class label $z_n$, and can be formulated as an $S$-step sequence $[\mathbf{C}_n, \mathbf{f}_n, \mathbf{p}] \in \mathbb{R}^{S \times 4}$. $\mathbf{C}_n = \{(x_n^s, y_n^s)\}_{s=1}^{S} \in \mathbb{R}^{S \times 2}$ is the coordinate sequence of the sketch points $X_n$. All sketch point coordinates have been uniformly scaled to $(x_n^s, y_n^s) \in [0, 256]^2$. If the true length of $\mathbf{C}_n$ is shorter than $S$, then the vector $[-1, -1]$ is used for padding. Flag bit vector $\mathbf{f}_n \in \{f_1, f_2, f_3\}^{S \times 1}$ is a ternary integer vector that denotes the pen state sequence corresponding to each point of $X_n$.
It is defined as follows: $f_1$ if the point $(x_n^s, y_n^s)$ is a starting or ongoing point of a stroke, $f_2$ if the point is the ending point of a stroke, and $f_3$ for a padding point. Vector $\mathbf{p} = [0, 1, 2, \dots, S-1]^T$ is a positional encoding vector that represents the temporal position of the points in each sketch $X_n$.
Given $\mathcal{D}$, we aim to model $X_n$ as multiple sparsely connected graphs and learn a deep embedding space in which high-level semantic tasks, e.g., sketch recognition, can be conducted.
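As an illustration of this notation, the following numpy helper (our own sketch, not from the paper's released code; the integer encoding 0/1/2 for the flag values $f_1/f_2/f_3$ is an assumption) assembles the $S \times 4$ sequence $[\mathbf{C}_n, \mathbf{f}_n, \mathbf{p}]$ from a list of strokes:

```python
import numpy as np

# Assumed integer encoding for the three pen states f1, f2, f3.
F_ONGOING, F_STROKE_END, F_PADDING = 0, 1, 2

def sketch_to_sequence(strokes, S=100):
    """Build the S x 4 array [C_n, f_n, p] from strokes, each an (m, 2) array
    of (x, y) key points. Short sketches are padded with [-1, -1] / f3."""
    coords, flags = [], []
    for stroke in strokes:
        for k, (x, y) in enumerate(stroke):
            coords.append((x, y))
            flags.append(F_STROKE_END if k == len(stroke) - 1 else F_ONGOING)
    coords, flags = coords[:S], flags[:S]      # truncate to S steps
    while len(coords) < S:                     # pad remaining steps
        coords.append((-1.0, -1.0))
        flags.append(F_PADDING)
    C = np.asarray(coords, dtype=np.float32)   # (S, 2) coordinates
    f = np.asarray(flags, dtype=np.float32)    # (S,)  pen states
    p = np.arange(S, dtype=np.float32)         # (S,)  positions 0..S-1
    return np.concatenate([C, f[:, None], p[:, None]], axis=1)  # (S, 4)
```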
3.2. Multi-Modal Input Layer

Given a sketch $X_n$, we model its $S$ stroke points as $S$ nodes of a graph. Each node has three features: (i) $C_n^s$ is the spatial positional information of the current stroke point $s$; (ii) $f_n^s$ is the pen state of the current stroke point, which helps to identify the stroke points belonging to the same stroke; and (iii) $p^s$ is the temporal information of the current stroke point. As sketching is a dynamic process, it is important to use the temporal information.

The complete model architecture for our Multi-Graph Transformer is presented in Figure 3. Let us start by describing the input layer. The final vector at node $s$ of the multi-modal input layer is defined as
$$(h_n^s)^{(l=0)} = \mathcal{C}\big(E_1(C_n^s),\, E_2(f_n^s),\, E_2(p^s)\big), \qquad (1)$$
where $E_1(C_n^s)$ is the embedding of $C_n^s$ with a linear layer of size $2 \times \hat{d}$, and $E_2(f_n^s)$ and $E_2(p^s)$ are the embeddings of the flag bit $f_n^s$ (3 discrete values) and the position encoding $p^s$ ($S$ discrete values) from an embedding dictionary of size
Figure 3. Multi-Graph Transformer architecture. Each MGT layer is composed of (i) a Multi-Graph Multi-Head Attention (MGMHA) sublayer and (ii) a position-wise fully connected Feed-Forward (FF) sublayer. See details in text. "B" denotes batch size.
$(S + 3) \times \hat{d}$, and $\mathcal{C}(\cdot, \cdot)$ is the concatenation operator. The node vector $(h_n^s)^{(l=0)}$ has dimension $d = 3\hat{d}$. The design of the input layer was selected after extensive ablation studies, which are described in subsequent sections.
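A minimal numpy sketch of Eq. (1), with random matrices standing in for the learned parameters. How the $(S+3)$-row dictionary is split between the 3 flag values and the $S$ positions is our assumption (positions in rows $0..S-1$, flags in the last 3 rows):

```python
import numpy as np

rng = np.random.default_rng(0)
S, d_hat = 100, 128

# Hypothetical parameters standing in for the learned layers of Eq. (1):
W1 = rng.standard_normal((2, d_hat)).astype(np.float32)      # E1: linear, 2 -> d_hat
b1 = np.zeros(d_hat, dtype=np.float32)
E2 = rng.standard_normal((S + 3, d_hat)).astype(np.float32)  # dictionary, (S+3) x d_hat

def input_embedding(C, f, p):
    """Eq. (1): h^(0) = concat(E1(C), E2(f), E2(p)); output (S, 3*d_hat)."""
    h_coord = C @ W1 + b1   # (S, d_hat) coordinate embedding
    h_flag  = E2[S + f]     # flags occupy the last 3 dictionary rows (assumption)
    h_pos   = E2[p]         # positions occupy rows 0..S-1 (assumption)
    return np.concatenate([h_coord, h_flag, h_pos], axis=1)
```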
3.3. Multi-Graph Transformer

The initial node embedding $(h_n^s)^{(l=0)}$ is updated by stacking $L$ Multi-Graph Transformer (MGT) layers (7). Let us describe all layers.
Figure 4. Multi-Head Attention Layer, consisting of several Graph Attention Layers in parallel.

Graph Attention Layer  Let $A$ be a graph adjacency matrix of size $S \times S$, and let $Q \in \mathbb{R}^{S \times d_q}$, $K \in \mathbb{R}^{S \times d_k}$, $V \in \mathbb{R}^{S \times d_v}$ be the query, key, and value matrices. We define a graph attention layer as

$$\text{GraphAttention}(Q, K, V, A) = A \odot \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V, \qquad (2)$$
where $\odot$ is the Hadamard product. We simply weight the "Scaled Dot-Product Attention" (Vaswani et al., 2017) with the graph edge weights. We set $d_q = d_k = d_v = d/I$, where $I$ is the number of attention heads.
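Eq. (2) can be sketched in a few lines of numpy; this is an illustration of the masking idea, not the authors' implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def graph_attention(Q, K, V, A):
    """Eq. (2): Hadamard-mask the scaled dot-product attention with A,
    so a node only aggregates values along existing graph edges."""
    d_k = K.shape[-1]
    scores = softmax(Q @ K.T / np.sqrt(d_k))  # (S, S) dense attention
    return (A * scores) @ V                   # keep only graph edges
```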
Multi-Head Attention Layer  We aggregate the graph attentions with multiple heads:

$$\text{MultiHead}(Q, K, V, A) = \mathcal{C}(\text{head}_1, \dots, \text{head}_I)\, W^O, \qquad (3)$$

where $W^O \in \mathbb{R}^{I d_v \times d}$ and each attention head is computed with the graph attention layer (2):

$$\text{head}_i = \text{GraphAttention}(Q W_i^Q, K W_i^K, V W_i^V, A), \qquad (4)$$

where $W_i^Q \in \mathbb{R}^{d \times d_q}$, $W_i^K \in \mathbb{R}^{d \times d_k}$, and $W_i^V \in \mathbb{R}^{d \times d_v}$. We add dropout (Srivastava et al., 2014) before the linear projections of $Q$, $K$ and $V$. An illustration of the Multi-Head Attention Layer is presented in Figure 4.
Multi-Graph Multi-Head Attention Layer  Given a set of adjacency graph matrices $\{A_g\}_{g=1}^{G}$, we can concatenate Multi-Head Attention Layers:

$$\text{MultiGraphMultiHeadAttention}(Q, K, V, \{A_g\}_{g=1}^{G}) = \text{ReLU}\big(\mathcal{C}(\text{ghead}_1, \dots, \text{ghead}_G)\, W^O\big), \qquad (5)$$

where $W^O \in \mathbb{R}^{Gd \times d}$ and each Multi-Head Attention Layer is computed with (3):

$$\text{ghead}_g = \text{MultiHead}(Q, K, V, A_g). \qquad (6)$$
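Putting Eqs. (3)-(6) together, here is a numpy sketch of one multi-graph multi-head attention pass. It assumes self-attention with $Q = K = V = H$; the random matrices stand in for learned projections, and dropout and batching are omitted:

```python
import numpy as np

rng = np.random.default_rng(0)
S, d, I, G = 100, 384, 8, 3
d_h = d // I  # d_q = d_k = d_v = d / I

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def graph_attention(Q, K, V, A):  # Eq. (2)
    return (A * softmax(Q @ K.T / np.sqrt(K.shape[-1]))) @ V

# Hypothetical random parameters standing in for learned projections:
Wq = rng.standard_normal((I, d, d_h)) * 0.02         # per-head W_i^Q
Wk = rng.standard_normal((I, d, d_h)) * 0.02         # per-head W_i^K
Wv = rng.standard_normal((I, d, d_h)) * 0.02         # per-head W_i^V
Wo_head  = rng.standard_normal((I * d_h, d)) * 0.02  # W^O of Eq. (3)
Wo_graph = rng.standard_normal((G * d, d)) * 0.02    # W^O of Eq. (5)

def multi_head(H, A):  # Eqs. (3)-(4): I heads on one graph
    heads = [graph_attention(H @ Wq[i], H @ Wk[i], H @ Wv[i], A)
             for i in range(I)]
    return np.concatenate(heads, axis=1) @ Wo_head

def multi_graph_multi_head(H, graphs):  # Eqs. (5)-(6): one block per graph
    gheads = [multi_head(H, A) for A in graphs]
    return np.maximum(np.concatenate(gheads, axis=1) @ Wo_graph, 0.0)  # ReLU
```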
Multi-Graph Transformer Layer  The Multi-Graph Transformer (MGT) at layer $l$ for node $s$ is defined as

$$(h_n^s)^{(l)} = \text{MGT}\big((h_n)^{(l-1)}\big) = \hat{h}_n^s + \text{FF}^{(l)}(\hat{h}_n^s), \qquad (7)$$

where the intermediate feature representation $\hat{h}_n^s$ is defined as:

$$\hat{h}_n^s = (\text{MGMHA}_n^s)^{(l)}\big((h_n^1)^{(l-1)}, \dots, (h_n^S)^{(l-1)}\big). \qquad (8)$$

The MGT layer is thus composed of (i) a Multi-Graph Multi-Head Attention (MGMHA) sublayer (5) and (ii) a position-wise fully connected Feed-Forward (FF) sublayer. Each MGMHA sublayer (6) and FF sublayer (7) has a residual connection (He et al., 2016) and batch normalization (Ioffe & Szegedy, 2015). See Figure 3 for an illustration.
3.4. Sketch Embedding and Classification Layer

Given a sketch $X_n$ with $t_n$ key points, its continuous representation $h_n$ is simply given by the sum over all its node features from the last MGT layer:

$$h_n = \sum_{s=1}^{t_n} (h_n^s)^{(L)}. \qquad (9)$$

Finally, we use a Multi-Layer Perceptron (MLP) to classify the sketch representation $h_n$, see Figure 3.
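Eq. (9) amounts to a sum over the $t_n$ real key points only, ignoring padded nodes, e.g.:

```python
import numpy as np

def sketch_embedding(H, t_n):
    """Eq. (9): sum the node features of the first t_n (real) key points
    from the last MGT layer; rows t_n..S-1 are padding and are skipped."""
    return H[:t_n].sum(axis=0)
```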
3.5. Sketch-Specific Graphs

In this section, we discuss the graph structures we used in our Graph Transformer layers. We considered two types of graphs, which capture local and global geometric sketch structures.

The first class of graphs focuses on representing the local geometry of individual strokes. We choose $K$-hop graphs to describe the local geometry of strokes. The intra-stroke adjacency matrix is defined as follows:

$$A_{n,ij}^{K\text{-hop}} = \begin{cases} 1 & \text{if } j \in \mathcal{N}_i^{K\text{-hop}} \text{ and } j \in \mathrm{global}(i), \\ 0 & \text{otherwise}, \end{cases} \qquad (10)$$

where $\mathcal{N}_i^{K\text{-hop}}$ is the $K$-hop neighborhood of node $i$ and $\mathrm{global}(i)$ is the stroke of node $i$.
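A possible construction of Eq. (10) in numpy, assuming the $K$-hop neighborhood of node $i$ is the set of points within $K$ steps of $i$ in drawing order (and that the diagonal, i.e. self-loops, is included):

```python
import numpy as np

def intra_stroke_adjacency(stroke_ids, K):
    """Eq. (10): A_ij = 1 iff |i - j| <= K and points i, j lie on the same
    stroke. stroke_ids is an (S,) array mapping each point to its stroke."""
    S = len(stroke_ids)
    idx = np.arange(S)
    same_stroke = stroke_ids[:, None] == stroke_ids[None, :]
    k_hop = np.abs(idx[:, None] - idx[None, :]) <= K
    return (same_stroke & k_hop).astype(np.float32)
```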
The second class of graphs captures the global and temporal relationships between the strokes composing the whole

Table 1. Summary statistics for our subset of QuickDraw.

Set        | # Samples | # Truncated (ratio) | # Key Points (max / min / mean / std)
Training   | 345,000   | 11,788 (3.42%)      | 100 / 2 / 43.26 / 21.85
Validation | 34,500    | 1,218 (3.53%)       | 100 / 2 / 43.24 / 21.89
Test       | 34,500    | 1,235 (3.58%)       | 100 / 2 / 43.20 / 21.93
sketch. We define the extra-stroke adjacency matrix as follows:

$$A_{n,ij}^{\mathrm{global}} = \begin{cases} 1 & \text{if } |i - j| = 1 \text{ and } \mathrm{global}(i) \neq \mathrm{global}(j), \\ 0 & \text{otherwise}. \end{cases} \qquad (11)$$
This graph will force the network to pay attention between
two points belonging to two distinct strokes but consecutive
in time, thus allowing the model to understand the relative
arrangement of strokes.
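Eq. (11) can be built the same way as the intra-stroke graph; only consecutive-in-time points that straddle a stroke boundary are connected:

```python
import numpy as np

def extra_stroke_adjacency(stroke_ids):
    """Eq. (11): A_ij = 1 iff |i - j| = 1 and points i, j lie on different
    strokes, i.e. temporally consecutive points across a stroke boundary."""
    S = len(stroke_ids)
    idx = np.arange(S)
    consecutive = np.abs(idx[:, None] - idx[None, :]) == 1
    diff_stroke = stroke_ids[:, None] != stroke_ids[None, :]
    return (consecutive & diff_stroke).astype(np.float32)
```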
4. Experiments

4.1. Experimental Setting

Dataset and Pre-Processing  Google QuickDraw (Ha & Eck, 2018)¹ is the largest available sketch dataset, containing 50 million sketches as simplified stroke key points in temporal order, sampled using the Ramer-Douglas-Peucker algorithm after uniformly scaling image coordinates within 0 to 256. Unlike smaller crowd-sourced sketch datasets, e.g., TU-Berlin (Eitz et al., 2012), QuickDraw samples were collected via an international online game where users have only 20 seconds to sketch objects from 345 classes, such as cats, dogs, clocks, etc. Thus, sketch classification on QuickDraw not only involves a diversity of drawing styles, but can also be highly abstract and noisy, making it a challenging and practical testbed for comparing the effectiveness of various neural network architectures. Following recent practices (Dey et al., 2019; Xu et al., 2018), we create random training, validation and test sets from the full dataset by sampling 1000, 100 and 100 sketches respectively from each of the 345 categories in QuickDraw. Following (Xu et al., 2018), we truncate or pad all samples to a uniform length of 100 key points/steps to facilitate efficient training of RNN and GNN-based models. We provide summary statistics for our training, validation and test sets in Table 1, and histograms visualizing the key points per sketch are shown in Figure 5.
Evaluation Metrics  Our evaluation metric for sketch recognition is "top-K accuracy", the proportion of samples whose true class is in the top K model predictions, for values K = 1, 5, 10. (Note that acc.@K = 1.0 means 100%.)
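Top-K accuracy can be computed directly from the model logits, e.g.:

```python
import numpy as np

def top_k_accuracy(logits, labels, k):
    """Fraction of samples whose true class is among the k highest logits.
    logits: (N, num_classes); labels: (N,) integer class indices."""
    topk = np.argsort(-logits, axis=1)[:, :k]     # indices of k largest per row
    hits = (topk == labels[:, None]).any(axis=1)  # true class among top k?
    return hits.mean()
```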
¹ https://quickdraw.withgoogle.com/data

Figure 5. Histograms of key points per sketch for our subset of QuickDraw: (a) Training, (b) Validation, (c) Test. The sharp spike at 100 key points is due to truncation.

Implementation Details  For fair comparison under similar hardware conditions, all experiments were implemented in PyTorch (Paszke et al., 2019) and run on one Nvidia 1080Ti GPU. For Transformer models, we use the following hyperparameter values: $S = 100$, $L = 4$, $\hat{d} = 128$, $G = 3$ ($A^{1\text{-hop}}$, $A^{2\text{-hop}}$, $A^{\mathrm{global}}$), and $I = 8$ (per graph) for our Base model (and $\hat{d} = 256$ for our Large model). Our FF sublayer is a $d$-dimensional linear layer ($d = 3\hat{d}$) followed by ReLU (Glorot et al., 2011) and dropout. The MLP classifier consists of two $4\hat{d}$-dimensional linear layers with ReLU and dropout, followed by a 345-dimensional linear projection representing logits over the 345 categories in QuickDraw. We train all models by minimizing the softmax cross-entropy loss using the Adam (Kingma & Ba, 2014) optimizer for 100 epochs. We use an initial learning rate of 5e-5 and multiply it by a factor of 0.7 every 10 epochs. We use an early stopping strategy (with a "patience" hyperparameter of 10 epochs), and the checkpoint with the highest validation performance is chosen to report test performance.
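The schedule above amounts to a step decay of the learning rate; a minimal sketch (an assumed helper, not the authors' training script):

```python
def learning_rate(epoch, base_lr=5e-5, decay=0.7, step=10):
    """Step-decay schedule: multiply base_lr by `decay` every `step` epochs."""
    return base_lr * decay ** (epoch // step)
```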
Baselines  (i) From the perspective of coordinate-based sketch recognition, RNN models are a simple-yet-effective baseline. Following Xu et al. (2018), we design several bidirectional LSTM (Hochreiter & Schmidhuber, 1997) and GRU (Cho et al., 2014) models at increasing parameter budgets comparable with MGT. The final RNN states are concatenated and passed to the MLP classifier described previously. We use batch size 256, initial learning rate 1e-4, and multiply it by 0.9 every 10 epochs. We train models with both our multi-modal input (Section 3.2) as well as the 4D input from Xu et al. (2018).
(ii) Although converting sketch coordinates to images adds time overhead in practical settings and can be seen as auxiliary information, we compare MGT to various state-of-the-art CNN architectures. It is important to note that sketch sequences were truncated/padded for training both MGT and RNNs, hence image-based CNNs stand as an upper bound in terms of performance. For Inception V3 (Szegedy et al., 2016) and MobileNet V2 (Sandler et al., 2018), the initial learning rate is 1e-3 and multiplied by 0.5 every 10 epochs. For other CNN baselines, the initial learning rate and decay are configured following their original papers. For each model, we use the maximum possible batch size. Following standard practice in computer vision (He et al., 2016; Huang