Paper Notes: Generative Teaching Networks
Overview
 Training data is a limitation of ML/AI: datasets are limited in size and carry inherent biases. This paper explores whether generating data might be the answer to the limitations of data.
 Notably, generating data may also bring other desirable properties, such as:
 Faster evaluation of architectures in NAS
 Better generalizability
Problem Domain
 Applications in reinforcement learning, supervised learning, and neural architecture search (NAS)
Links
Notes
Unlike in a GAN, the two models (a generator and a learner) cooperate rather than compete. Learning consists of two nested loops, an inner loop and an outer loop, trained together with meta-learning via nested optimization.
Inner Loop
 The generator takes Gaussian noise and a label as input and produces a synthetic sample (it could also generate the label).
 The learner is trained on the synthetic samples for a fixed number of steps (SGD with momentum).
 The authors also learn a curriculum in the inner loop (3 variants, compared against a no-curriculum baseline), parameterized as part of the overall network.
 Loss: MSE for regression, cross-entropy for classification.
 The inner-loop objective does not depend on the outer-loop objective.
Outer Loop
 The trained learner is then evaluated on real data, which yields the meta-training loss.
 The gradient of the meta-training loss is backpropagated through the entire inner-loop training process to update the generator.
 The learner's hyperparameters are also parameterized and included in the gradient calculation, so they are meta-learned as well.
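To make the nested optimization concrete, here is a minimal sketch in PyTorch. It is not the authors' implementation: a toy Gaussian-blob dataset stands in for MNIST, the learner is a linear classifier written functionally so its updated weights stay in the autograd graph, SGD with momentum is written out by hand, and the meta-learned hyperparameter is a hypothetical log learning rate (`log_lr`). All names here are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
NOISE, CLASSES, FEATURES = 8, 3, 5

# Hypothetical toy generator: noise + one-hot label -> synthetic feature vector
generator = nn.Sequential(nn.Linear(NOISE + CLASSES, 32), nn.ReLU(), nn.Linear(32, FEATURES))
# Hypothetical meta-learned inner-loop hyperparameter (log of the learning rate)
log_lr = nn.Parameter(torch.tensor(-2.0))
meta_opt = torch.optim.Adam(list(generator.parameters()) + [log_lr], lr=1e-2)

# Toy "real" dataset standing in for MNIST: one Gaussian blob per class
centers = torch.randn(CLASSES, FEATURES) * 3
real_y = torch.randint(CLASSES, (256,))
real_x = centers[real_y] + torch.randn(256, FEATURES)

def learner_forward(params, x):
    # Linear learner in functional form, so updated weights remain differentiable
    w, b = params
    return x @ w + b

def inner_loop(steps=10, batch=32, momentum=0.9):
    w = torch.zeros(FEATURES, CLASSES, requires_grad=True)
    b = torch.zeros(CLASSES, requires_grad=True)
    params, velocity = [w, b], [torch.zeros_like(w), torch.zeros_like(b)]
    for _ in range(steps):
        y = torch.randint(CLASSES, (batch,))
        z = torch.cat([torch.randn(batch, NOISE), F.one_hot(y, CLASSES).float()], dim=1)
        loss = F.cross_entropy(learner_forward(params, generator(z)), y)
        # create_graph=True keeps each SGD-with-momentum update differentiable
        grads = torch.autograd.grad(loss, params, create_graph=True)
        velocity = [momentum * v + g for v, g in zip(velocity, grads)]
        params = [p - log_lr.exp() * v for p, v in zip(params, velocity)]
    return params

for outer_step in range(50):
    params = inner_loop()
    # Outer loop: evaluate the trained learner on real data -> meta-training loss
    meta_loss = F.cross_entropy(learner_forward(params, real_x), real_y)
    meta_opt.zero_grad()
    meta_loss.backward()  # meta-gradient flows back through every inner-loop step
    meta_opt.step()
```

Because `create_graph=True` keeps every inner update in the graph, memory use grows linearly with the number of inner-loop steps.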
Claims
 Weight normalization improves stability of meta-learning
 Curriculum learning is useful for GTN
 Faster NN training on CIFAR-10, MNIST, and CartPole compared to training with real data
 Faster NAS for CIFAR-10 with limited compute
Claim 1: Weight Normalization improves stability of meta-learning
The authors treat weight normalization as analogous to batch normalization and report that it greatly increases the stability of the training process.^{1}
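# Weight normalization layers in PyTorch (from the GTN repo, models.py)
import torch
import torch.nn as nn
from exconv import ExConv2d  # custom autograd Function from the GTN repo (not part of PyTorch)

def make_weight_norm_layer(base_cls):
    """Wrap a layer class so each output channel is rescaled by a learned gain g / ||w||."""
    class WeightNorm(base_cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Learned per-output gain, initialised to the current norm of each weight row/filter
            self.weight_g = nn.Parameter(torch.norm_except_dim(self.weight, 2, 0).data)

        def forward(self, x):
            x = super().forward(x)
            # Rescale the output instead of the weight (matches g * w / ||w|| for bias-free
            # layers; if the layer has a bias, the bias is rescaled too)
            x = x * (self.weight_g / torch.norm_except_dim(self.weight, 2, 0)).transpose(1, 0)
            return x
    return WeightNorm

class Linear(nn.Linear):
    def reset_parameters(self):
        super().reset_parameters()
        torch.nn.init.kaiming_normal_(self.weight)

class Conv2d(nn.Conv2d):
    def reset_parameters(self):
        super().reset_parameters()
        torch.nn.init.kaiming_normal_(self.weight)

    def forward(self, x):
        # Use the custom ExConv2d kernel for bias-free, ungrouped convolutions
        if self.bias is None and self.groups == 1 and not getattr(self, "disable_exconv", False):
            return ExConv2d.apply(x, self.weight, self.padding, self.stride, self.dilation, self.groups)
        else:
            return super().forward(x)

WeightNormLinear = make_weight_norm_layer(Linear)
WeightNormConv2d = make_weight_norm_layer(Conv2d)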
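Weight normalization (Salimans & Kingma, 2016) reparameterizes each weight vector as w = g · v / ||v||, separating its direction v from a learned scale g. `make_weight_norm_layer` gets the same effect without touching the weight: it runs the ordinary layer and then multiplies each output channel by g / ||v||. A minimal check of that equivalence for a bias-free linear layer, with an arbitrary gain standing in for a learned one (with a bias, the output rescaling would scale the bias as well):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Bias-free linear layer; its weight v has shape (out_features, in_features)
layer = nn.Linear(4, 3, bias=False)
v = layer.weight.detach()
row_norms = v.norm(dim=1, keepdim=True)  # ||v|| per output unit, shape (3, 1)

# Gain g, doubled to mimic a learned gain that has moved away from its initial value
g = 2.0 * row_norms

x = torch.randn(5, 4)

# (a) Textbook weight normalization: build w = g * v / ||v||, then a plain matmul
w = g * v / row_norms
out_reparam = x @ w.t()

# (b) Output rescaling, the approach taken by make_weight_norm_layer
out_rescaled = layer(x) * (g / row_norms).t()

print(torch.allclose(out_reparam, out_rescaled, atol=1e-6))  # True
```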
Claim 2: Curriculum Learning is useful for GTN
The authors report that learning a curriculum improves performance.
 4 curricula studied
 No curriculum
 trains a generator to output synthetic training data from noise vectors freshly sampled from a Gaussian distribution
 Can theoretically generate infinite training data, but performs the worst of the four variants
 All shuffled
 learns a fixed set of 4096 inputs presented in a random order without replacement
 i.e. learning controls the data but not the order or batch makeup
 Shuffled Batch
 learns 32 batches of 128 inputs each
 i.e. learns to control what’s in each batch but not the order
 Full Curriculum
 learns a deterministic sequence of 32 batches of 128 samples
 i.e. controls both the order and the content of batches
# Learned curricula from the GTN repo (Full Curriculum, All Shuffled, Shuffled Batch)
import torch
import torch.nn as nn

class CGTN(nn.Module):
    """Full curriculum: a learned, deterministic sequence of noise batches, one per inner-loop step."""
    def __init__(self, generator, num_inner_iterations, generator_batch_size, noise_size, evenly_distributed_labels=False, meta_learn_labels=False):
        super().__init__()
        self.generator_batch_size = generator_batch_size
        if evenly_distributed_labels:
            # Assign labels round-robin over the 10 classes
            labels = torch.arange(num_inner_iterations * generator_batch_size) % 10
            labels = torch.reshape(labels, (num_inner_iterations, generator_batch_size))
            self.curriculum_labels = nn.Parameter(labels, requires_grad=False)
        else:
            self.curriculum_labels = nn.Parameter(torch.randint(10, size=(num_inner_iterations, generator_batch_size), dtype=torch.int64), requires_grad=False)
        # One-hot labels of shape (iterations, batch, 10): scatter_ on dim 2 needs an index of shape (iterations, batch, 1)
        self.curriculum_labels_one_hot = torch.zeros(num_inner_iterations, generator_batch_size, 10)
        self.curriculum_labels_one_hot.scatter_(2, self.curriculum_labels.unsqueeze(2), 1)
        self.curriculum_labels_one_hot = nn.Parameter(self.curriculum_labels_one_hot, requires_grad=meta_learn_labels)
        # TODO: Maybe learn the soft labels?
        # Learned noise vectors: one batch of generator inputs per inner-loop step
        self.curriculum = nn.Parameter(torch.randn((num_inner_iterations, generator_batch_size, noise_size), dtype=torch.float32))
        self.generator = torch.jit.trace(generator, (torch.rand(generator_batch_size, noise_size + 10),))

    def forward(self, it):
        # Inner-loop step `it` always receives the same learned batch
        label = self.curriculum_labels_one_hot[it]
        noise = torch.cat([self.curriculum[it], label], dim=1)
        x = self.generator(noise)
        if not x.requires_grad:
            label = label.detach()
        return x, label


class CGTNAllShuffled(CGTN):
    """All shuffled: the learned inputs are pooled and drawn at random, ignoring batch and order."""
    def forward(self, it):
        # Flatten (iterations, batch, ...) into a single pool of learned inputs
        all_images = torch.reshape(self.curriculum, (-1,) + self.curriculum.shape[2:])
        all_labels = torch.reshape(self.curriculum_labels_one_hot, (-1,) + self.curriculum_labels_one_hot.shape[2:])
        # Note: torch.randint samples with replacement
        idx = torch.randint(len(all_images), size=(self.generator_batch_size,), device=all_images.device)
        noise = all_images[idx]
        labels = all_labels[idx]
        noise = torch.cat([noise, labels], dim=1)
        x = self.generator(noise)
        return x, labels


class CGTNBatchShuffled(CGTN):
    """Shuffled batch: learned batches are kept intact but picked at random."""
    def forward(self, it):
        idx = torch.randint(len(self.curriculum), size=())
        noise = self.curriculum[idx]
        labels = self.curriculum_labels_one_hot[idx]
        noise = torch.cat([noise, labels], dim=1)
        x = self.generator(noise)
        return x, labels
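The four curricula differ only in how a learned pool of noise vectors is indexed during one pass of the inner loop. The hypothetical helper `curriculum_indices` below (not from the GTN repo) makes that explicit. It uses `torch.randperm`, so All Shuffled draws without replacement as the notes describe, whereas the `torch.randint` samplers in the repo code draw with replacement. The paper's setting would be `num_batches=32, batch_size=128` (4096 inputs); the usage line uses tiny sizes.

```python
import torch

def curriculum_indices(kind, num_batches, batch_size, generator=None):
    """Index schedule for one inner-loop pass over a learned pool of
    num_batches * batch_size noise vectors. Returns a (num_batches, batch_size)
    LongTensor, or None for "none" (fresh Gaussian noise at every step)."""
    total = num_batches * batch_size
    if kind == "none":
        return None
    if kind == "all_shuffled":
        # Every learned input is used exactly once per pass, in random batches and order
        return torch.randperm(total, generator=generator).view(num_batches, batch_size)
    fixed = torch.arange(total).view(num_batches, batch_size)  # fixed batch contents
    if kind == "shuffled_batch":
        # Batch contents are fixed; only the order of the batches is random
        return fixed[torch.randperm(num_batches, generator=generator)]
    if kind == "full":
        # Fixed contents presented in a fixed order
        return fixed
    raise ValueError(f"unknown curriculum: {kind}")

# Usage: at each inner-loop step, index the learned noise pool with one row of the schedule
noise_pool = torch.nn.Parameter(torch.randn(4 * 3, 8))
rng = torch.Generator().manual_seed(0)
schedule = curriculum_indices("all_shuffled", num_batches=4, batch_size=3, generator=rng)
first_batch = noise_pool[schedule[0]]  # shape (3, 8), differentiable w.r.t. noise_pool
```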
Claim 3: Faster NN training times for CIFAR, MNIST, CartPole
 GTN for supervised learning
 Experiment setup:
 Data: MNIST images or GTN-generated MNIST images
 Hyperparameters:
 3 setups trained
 Real data: learners trained on randomly sampled real examples
 Dataset Distillation: learners trained on synthetic data whose training examples are encoded directly as tensors optimized by the meta-objective (from Wang et al 2019b)
 GTN: the authors' method
 Results
 GTN outperforms the benchmarks on MNIST
 The authors note that state-of-the-art accuracy is not the point; what matters is how well and how inexpensively a learner can be trained. This is important for NAS because it makes it possible to identify and evaluate architectures quickly.
 Very interestingly, the images generated by the GTN do not obviously look like digits.^{2}
# Setup 3: GTN with no curriculum (fresh noise and random labels at every step)
import torch
import torch.nn as nn

class GTN(nn.Module):
    def __init__(self, generator, generator_batch_size, noise_size):
        super().__init__()
        self.generator_batch_size = generator_batch_size
        self.noise_size = noise_size
        # Trace the generator with a dummy (noise + 10-way one-hot label) input
        self.generator = torch.jit.trace(generator, (torch.rand(generator_batch_size, noise_size + 10),))

    def forward(self, it):
        # `it` is unused: every call draws new labels and new noise (assumes a CUDA device)
        curriculum_labels = torch.randint(10, size=(self.generator_batch_size,), dtype=torch.int64, device="cuda")
        curriculum_labels_one_hot = torch.zeros(self.generator_batch_size, 10, device="cuda")
        curriculum_labels_one_hot.scatter_(1, curriculum_labels.unsqueeze(1), 1)
        noise = torch.cat([torch.randn(self.generator_batch_size, self.noise_size, device="cuda"), curriculum_labels_one_hot], dim=1)
        x = self.generator(noise)
        return x, curriculum_labels_one_hot

# Setup 2: Dataset Distillation (the synthetic images themselves are the learned parameters)
class DatasetDistillation(nn.Module):
    def __init__(self, num_inner_iterations, generator_batch_size, img_shape):
        super().__init__()
        self.curriculum_labels = nn.Parameter(torch.randint(10, size=(num_inner_iterations, generator_batch_size), dtype=torch.int64), requires_grad=False)
        # One-hot labels of shape (iterations, batch, 10)
        self.curriculum_labels_one_hot = torch.zeros(num_inner_iterations, generator_batch_size, 10)
        self.curriculum_labels_one_hot.scatter_(2, self.curriculum_labels.unsqueeze(2), 1)
        self.curriculum_labels_one_hot = nn.Parameter(self.curriculum_labels_one_hot, requires_grad=False)
        # One learned image per sample per inner-loop step; img_shape is a tuple such as (1, 28, 28)
        self.curriculum = nn.Parameter(torch.randn((num_inner_iterations, generator_batch_size) + img_shape, dtype=torch.float32))

    def forward(self, it):
        x = self.curriculum[it]
        # Squash the learned pixels into (-2, 2)
        return torch.tanh(x) * 2, self.curriculum_labels_one_hot[it]
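The practical difference between setups 2 and 3 is where the learned parameters live. `DatasetDistillation` learns every pixel of every inner-loop batch, so its size grows linearly with the number of inner-loop steps, while a GTN's generator has a fixed size no matter how many steps it feeds (the curriculum variants add learned noise vectors per step, but only `noise_size` values per sample rather than a full image). A quick count, assuming a hypothetical two-layer MLP generator and MNIST-sized 1×28×28 images:

```python
import math
import torch.nn as nn

def dataset_distillation_params(num_inner_iterations, batch_size, img_shape):
    # DatasetDistillation learns one tensor of shape (num_inner_iterations, batch_size, *img_shape)
    return num_inner_iterations * batch_size * math.prod(img_shape)

# Hypothetical small MLP generator: 64-d noise + 10-way one-hot label -> 28x28 image
generator = nn.Sequential(
    nn.Linear(64 + 10, 1024), nn.ReLU(),
    nn.Linear(1024, 28 * 28), nn.Tanh(),
)
generator_params = sum(p.numel() for p in generator.parameters())

for steps in (32, 64, 128):
    dd = dataset_distillation_params(steps, 128, (1, 28, 28))
    print(f"{steps} inner steps: dataset distillation {dd:,} params, generator {generator_params:,} params")
```

With 32 inner steps of 128 images, dataset distillation already learns 3,211,264 values, and the count doubles every time the inner loop doubles; this example generator stays at 880,400 parameters.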
Claim 4: Faster NAS with limited compute
 The authors argue that GTNs are well suited to NAS because they reduce the computational cost of evaluating candidate architectures.
 In other words, the GTN is judged by how well a short training run on its synthetic data serves as a proxy evaluation of candidate architectures.
 The authors measured how good this proxy is by computing the Spearman rank correlation between architecture performance after 128 steps of training on GTN synthetic data and after 100 epochs of training on real data.
 Correlation: 0.3606 overall, or 0.5582 when considering only the top 50% of architectures
 Experiment Setup:
 Data: CIFAR-10
 Benchmark: NAO^{3}
 Benchmark Code: NAO benchmark code
 Results
 GTN-NAS improves on the NAS state of the art, finding higher-performing architectures than comparable methods such as weight sharing and Graph HyperNetworks^{4}
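Spearman rank correlation compares orderings rather than raw scores, which is what matters for NAS: a proxy is useful if it ranks architectures the same way full training does. Below is a minimal pure-Python sketch (tied values get average ranks; `scipy.stats.spearmanr` does the same job in practice). The scores are made-up placeholders for four hypothetical architectures, not data from the paper.

```python
def average_ranks(values):
    """1-based ranks; tied values share the average of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        for k in range(start, end + 1):
            ranks[order[k]] = (start + end) / 2 + 1
        start = end + 1
    return ranks

def spearman(x, y):
    """Spearman rank correlation = Pearson correlation of the ranks."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)

# Made-up scores for four hypothetical architectures (not results from the paper)
proxy_acc = [0.10, 0.40, 0.35, 0.80]  # after a short training run on GTN data
final_acc = [0.20, 0.50, 0.60, 0.90]  # after full training on real data
print(round(spearman(proxy_acc, final_acc), 4))  # 0.8
```

To mirror the “top 50%” figure, filter both lists to the chosen subset of architectures before calling `spearman`; which ranking defines that subset is not specified in these notes.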
# NAS setup used in the paper (excerpt from the GTN repo). sample_model, Classifier, models,
# Learner, tlogger and inner_optimizers come from other modules of that repo.
import numpy as np
import torch
import torch.nn as nn

class AutoML(nn.Module):
    def __init__(self, generator, optimizers, initial_batch_norm_momentum=0.9):
        super().__init__()
        self.generator = generator
        self.optimizers = torch.nn.ModuleList(optimizers)
        # Meta-learned batch-norm momentum, stored as a logit so sigmoid keeps it in (0, 1)
        self.batch_norm_momentum_logit = nn.Parameter(torch.as_tensor(inner_optimizers.inv_sigmoid(initial_batch_norm_momentum)))

    @property
    def batch_norm_momentum(self):
        return torch.sigmoid(self.batch_norm_momentum_logit)

    def sample_learner(self, input_shape, device, allow_nas=False, learner_type="base",
                       iteration_maps_seed=False, iteration=None, deterministic=False, iterations_depth_schedule=100, randomize_width=False):
        """Build a learner (architecture + inner-loop optimizer) for one inner loop."""
        if iteration_maps_seed:
            # Derive a fixed architecture encoding from the iteration number
            iteration = iteration - 1
            encoding = [iteration % 6, iteration // 6]
        else:
            encoding = None
        if learner_type == "sampled":
            # Sampled architecture whose depth grows with the iteration count (capped at 4)
            layers = min(4, max(0, iteration // iterations_depth_schedule))
            model, encoding = sample_model(input_shape, layers=layers, encoding=encoding, blocks=2,
                                           seed=iteration if deterministic else None, batch_norm_momentum=0)
            tlogger.record_tabular("encoding", encoding)
        elif learner_type == "sampled4":
            model, encoding = sample_model(input_shape, layers=4, encoding=encoding, seed=iteration if deterministic else None, batch_norm_momentum=0)
            tlogger.record_tabular("encoding", encoding)
        elif learner_type == "base":
            model = Classifier(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_fc":
            model = Classifier(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width, use_global_pooling=False)
        elif learner_type == "linear":
            model = models.LinearClassifier(input_shape)
        elif learner_type == "base_larger":
            model = models.ClassifierLarger(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_larger2":
            model = models.ClassifierLarger2(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_larger3":
            model = models.ClassifierLarger3(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_larger3_global_pooling":
            model = models.ClassifierLarger3(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width, use_global_pooling=True)
        elif learner_type == "base_larger4_global_pooling":
            model = models.ClassifierLarger4(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width, use_global_pooling=True)
        elif learner_type == "base_larger4":
            model = models.ClassifierLarger4(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width, use_global_pooling=False)
        else:
            raise NotImplementedError()
        # Pair the model with a randomly chosen (meta-learned) inner-loop optimizer
        return Learner(model=model.to(device), optimizer=np.random.choice(self.optimizers)), encoding
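`AutoML` meta-learns the batch-norm momentum through an unconstrained logit (`batch_norm_momentum_logit`) and maps it back with `torch.sigmoid`, so gradient updates can move the logit freely while the momentum stays between 0 and 1. A small sketch of that round trip; `inv_sigmoid` here is a hypothetical stand-in for the repo's `inner_optimizers.inv_sigmoid`, assumed to compute the logit log(p / (1 - p)).

```python
import math
import torch

def inv_sigmoid(p):
    # Hypothetical stand-in for inner_optimizers.inv_sigmoid: the logit log(p / (1 - p))
    return math.log(p / (1.0 - p))

# Meta-learn an unconstrained logit instead of the momentum itself
momentum_logit = torch.nn.Parameter(torch.tensor(inv_sigmoid(0.9)))
momentum = torch.sigmoid(momentum_logit)  # sigmoid maps any real logit into (0, 1), up to float rounding
print(round(momentum.item(), 4))  # 0.9

# Gradients reach the logit: d sigmoid / d logit = s * (1 - s) = 0.9 * 0.1
momentum.backward()
print(round(momentum_logit.grad.item(), 4))  # 0.09
```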
Important Conclusions
 Weight normalization can be used to stabilize meta-training in many different environments
 Many possible extensions of GTNs
 Unsupervised learning for creating a useful embedding function
 Stabilizing GAN training and preventing mode collapse
 Introducing a closed-loop curriculum that adapts to the performance of the learner
 i.e. early on the data would be easy, becoming progressively more complex as the learner’s performance improves. This is analogous to how a tutor teaches a student: they don’t jump straight to calculus, they teach algebra first.
 The generated data does not need to resemble the real training data to achieve good results; the paper’s appendix gives many examples.
 GTNs can bootstrap learners in complex ways that combine many aspects of learning: by controlling the data the learner ingests, they control the learning itself.
Future Directions
GTNs could also one day be a key to creating AI-generating algorithms, which seek to bootstrap themselves from simple initial conditions to powerful forms of AI by creating an open-ended stream of challenges (learning opportunities) while learning to solve them (Clune, 2019).
 Bootstrapping learners from simple initial conditions
 Creating NNs “on-demand”
Implementation Notes
Some notes on how the GTN is implemented. The code can be found here: Codebase
Stack
 Language: Python/C++
 DL Framework: PyTorch
 Parallelism: Horovod/MPI
Another notable point about the implementation is the use of the Fire library, which automatically creates a CLI for your Python programs.^{5}
 Fire was made by Google…surprise
 Codebase
My Take
Given how many problems in AI today come from bad data (ill-formed, biased, incomplete, etc.), I find the GTN architecture very interesting. In particular, the prospect of theoretically infinite data generation for RL tasks intrigues me.
From the paper:
GTNs are exciting because they can encode a rich set of possible environments with minimal assumptions, ranging from labeled data for supervised learning to (in theory) entire complex virtual RL domains (with their own learned internal physics)
Appendix K addresses my biggest question about this paper: how does the learner perform when the generator produces unrealistic images?
Many of the images generated by GTNs are unrecognizable (e.g. as digits)
They then go on to say:
Thus, a network can learn to get over 98% accuracy on MNIST training only on unrecognizable images.
The authors appear undecided about the explanation, since they list several possible “camps” of thought on why this works.

^{1} Weight normalization code: https://github.com/uber-research/GTN/blob/19799828d4ddd6b4e1fe837e38bb9ce4d3e588ce/models.py#L461

^{2} I think this is one of the most interesting parts of the paper, discussed more later.

^{3} NAO, Luo 2018: https://arxiv.org/abs/1808.07233

^{4} Graph HyperNetworks: https://arxiv.org/abs/1810.05749

^{5} I use Fire at work all the time now.