Generative Teaching Networks
- Training data is a limitation of ML/AI: datasets are limited in size and carry inherent biases. This paper explores whether generating data could be the answer to these limitations.
- Generating data may also bring other desirable properties, such as:
  - Faster evaluation of architectures in NAS
  - Better generalizability
- Applications in Reinforcement Learning, Supervised Learning, Neural Architecture Search
Unlike in a GAN, the two models (generator and learner) cooperate instead of competing. Learning consists of two loops, an inner loop and an outer loop, trained together via meta-learning with nested optimization.
- The generator takes Gaussian noise and a label as input and produces a sample (it could also generate the label)
- In the inner loop, the learner is trained on generated data for a fixed number of steps (SGD with momentum)
- The authors also learn curricula for the inner loop (3 variants, compared against a no-curriculum baseline), parameterized as part of the overall network
- Loss: MSE for regression, cross-entropy for classification
- The inner-loop objective does not depend on the outer-loop objective
- The learner is then evaluated on real data, which yields the meta-training loss
- The gradient of the meta-training loss is backpropagated through the entire inner-loop learning process to update the generator
- The learner's hyperparameters are also parameterized and included in the gradient calculation
- Weight normalization improves stability of meta learning
- Curriculum learning is useful for GTN
- Faster NN training on CIFAR-10, MNIST, and CartPole compared to training with real data
- Faster NAS for CIFAR-10 with limited compute
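To make the inner/outer loop concrete, here is a toy sketch in plain Python (no PyTorch). It is not the authors' implementation and every name in it is hypothetical: the "generator" is a single learned scale `a` that turns fixed noise `z` into synthetic targets `a * z`, the learner is a one-weight linear model trained for a few SGD steps, and the real task is `y = 3x`. The meta-gradient is computed exactly by carrying `dw/da` forward through each inner step (forward-mode differentiation), standing in for the backpropagation through the inner loop that the paper uses.

```python
import random


def inner_loop(a, zs, lr, steps):
    """Train a one-weight learner y = w * x on synthetic pairs (z, a * z).

    Returns the final weight w_K and its exact derivative dw_K/da,
    tracked forward through every SGD step.
    """
    n = len(zs)
    w, dw_da = 0.0, 0.0
    for _ in range(steps):
        # MSE gradient on the synthetic batch: mean(2 * (w - a) * z^2)
        grad_w = sum(2.0 * (w - a) * z * z for z in zs) / n
        # how that gradient changes with the generator parameter a
        dgrad_da = sum(2.0 * (dw_da - 1.0) * z * z for z in zs) / n
        w, dw_da = w - lr * grad_w, dw_da - lr * dgrad_da
    return w, dw_da


def meta_loss(w, dw_da, real):
    """MSE of the trained learner on real data, plus its gradient w.r.t. a."""
    n = len(real)
    loss = sum((w * x - y) ** 2 for x, y in real) / n
    dloss_dw = sum(2.0 * (w * x - y) * x for x, y in real) / n
    return loss, dloss_dw * dw_da  # chain rule through the whole inner loop


def train_toy_gtn(outer_steps=500, inner_steps=3, lr=0.1, meta_lr=1.0, seed=0):
    rng = random.Random(seed)
    zs = [rng.gauss(0.0, 1.0) for _ in range(16)]                                # fixed noise inputs
    real = [(x, 3.0 * x) for x in (rng.uniform(-1.0, 1.0) for _ in range(32))]   # real task: y = 3x
    a = 1.0                                                                      # generator parameter
    for _ in range(outer_steps):
        w, dw_da = inner_loop(a, zs, lr, inner_steps)  # inner loop on synthetic data
        _, grad_a = meta_loss(w, dw_da, real)          # meta-training loss on real data
        a -= meta_lr * grad_a                          # outer-loop update of the generator
    w, dw_da = inner_loop(a, zs, lr, inner_steps)
    loss, _ = meta_loss(w, dw_da, real)
    return a, w, loss


a, w, loss = train_toy_gtn()
```

With only 3 inner steps the learner cannot reach the target on its own, so the learned `a` ends up larger than the true slope 3: the generator exaggerates its targets so that a briefly trained learner lands on `w ≈ 3`. Even in this toy, the most useful synthetic data does not look like the real data.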
Claim 1: Weight Normalization improves stability of meta-learning
The claim is that weight normalization plays a role analogous to batch normalization. The authors report that weight normalization greatly increases the stability of the meta-training process.[1]
```python
# Weight normalization layers in PyTorch
import torch
import torch.nn as nn

from exconv import ExConv2d  # the repo's own custom convolution op (not part of PyTorch)


def make_weight_norm_layer(base_cls):
    class WeightNorm(base_cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # g: one learned scale per output unit, initialized to ||v|| so the layer starts unchanged
            self.weight_g = nn.Parameter(torch.norm_except_dim(self.weight, 2, 0).data)

        def forward(self, x):
            x = super().forward(x)
            # rescale each output unit by g / ||v|| (applied after any bias has been added)
            x = x * (self.weight_g / torch.norm_except_dim(self.weight, 2, 0)).transpose(1, 0)
            return x

    return WeightNorm


class Linear(nn.Linear):
    def reset_parameters(self):
        super().reset_parameters()
        torch.nn.init.kaiming_normal_(self.weight)


class Conv2d(nn.Conv2d):
    def reset_parameters(self):
        super().reset_parameters()
        torch.nn.init.kaiming_normal_(self.weight)

    def forward(self, x):
        # use the custom kernel for bias-free, ungrouped convolutions
        if self.bias is None and self.groups == 1 and not getattr(self, "disable_exconv", False):
            return ExConv2d.apply(x, self.weight, self.padding, self.stride, self.dilation, self.groups)
        else:
            return super().forward(x)


WeightNormLinear = make_weight_norm_layer(Linear)
WeightNormConv2d = make_weight_norm_layer(Conv2d)
```
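The layers above implement the standard weight-normalization reparameterization w = g · v / ‖v‖: the direction of each output unit's weights comes from `v`, and its length comes from the separate learned scale `g`. Rather than rebuilding `w`, the repo's `forward()` runs the layer with `v` and rescales the output by `g / ‖v‖`. For a bias-free linear map the two are equivalent, which this stdlib sketch checks. The function names and numbers are illustrative only, not from the GTN codebase.

```python
import math


def weight_norm(v_rows, g):
    """Weight normalization: each output row becomes w_i = g_i * v_i / ||v_i||."""
    w = []
    for v, g_i in zip(v_rows, g):
        norm = math.sqrt(sum(x * x for x in v))
        w.append([g_i * x / norm for x in v])
    return w


def matvec(rows, x):
    """Bias-free linear layer: y_i = sum_j rows[i][j] * x[j]."""
    return [sum(r * xj for r, xj in zip(row, x)) for row in rows]


v = [[3.0, 4.0], [1.0, -1.0]]  # direction parameters (illustrative values)
g = [2.0, 0.5]                 # per-output-unit scale parameters
x = [0.7, -1.2]                # an input vector

w = weight_norm(v, g)
# Rescaling v leaves w unchanged: only g controls the length of each row.
w_from_scaled_v = weight_norm([[10.0 * e for e in row] for row in v], g)

# Using w directly vs. scaling the output of v by g / ||v|| (the repo's approach).
y_direct = matvec(w, x)
y_rescaled = [y * g_i / math.sqrt(sum(e * e for e in row))
              for y, g_i, row in zip(matvec(v, x), g, v)]
```

Because each row's length is pinned to `g`, the effective weight scale cannot drift just because `v` grows or shrinks during meta-training, which is one plausible reason the reparameterization helps stability.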
Claim 2: Curriculum Learning is useful for GTN
The authors report that curriculum learning improves performance.
- Four curriculum variants were studied:
  - No curriculum
    - Trains a generator to output synthetic training data from noise vectors sampled from a Gaussian distribution
    - Can theoretically generate infinite training data, but performs the worst of the four variants
  - All shuffled
    - Learns a fixed set of 4096 inputs, presented in a random order without replacement
    - i.e. the GTN controls the data, but not the order or the batch composition
  - Shuffled batch
    - Learns 32 batches of 128 inputs each
    - i.e. it controls what is in each batch, but not the order of the batches
  - Full curriculum
    - Learns a deterministic sequence of 32 batches of 128 samples
    - i.e. it controls both the order and the content of the batches
```python
# Curriculum variants: full curriculum (CGTN), all shuffled, shuffled batch
import torch
import torch.nn as nn


class CGTN(nn.Module):
    def __init__(self, generator, num_inner_iterations, generator_batch_size, noise_size,
                 evenly_distributed_labels=False, meta_learn_labels=False):
        super().__init__()
        self.generator_batch_size = generator_batch_size
        self.generator = generator
        if evenly_distributed_labels:
            labels = torch.arange(num_inner_iterations * generator_batch_size) % 10
            labels = torch.reshape(labels, (num_inner_iterations, generator_batch_size))
            self.curriculum_labels = nn.Parameter(labels, requires_grad=False)
        else:
            self.curriculum_labels = nn.Parameter(
                torch.randint(10, size=(num_inner_iterations, generator_batch_size), dtype=torch.int64),
                requires_grad=False)
        self.curriculum_labels_one_hot = torch.zeros(num_inner_iterations, generator_batch_size, 10)
        self.curriculum_labels_one_hot.scatter_(2, self.curriculum_labels.unsqueeze(-1), 1)
        self.curriculum_labels_one_hot = nn.Parameter(self.curriculum_labels_one_hot,
                                                      requires_grad=meta_learn_labels)
        # TODO: Maybe learn the soft-labels?
        # one learned noise batch per inner-loop step
        self.curriculum = nn.Parameter(
            torch.randn((num_inner_iterations, generator_batch_size, noise_size), dtype=torch.float32))
        self.generator = torch.jit.trace(self.generator, (torch.rand(generator_batch_size, noise_size + 10),))

    def forward(self, it):
        # inner-loop step `it` always sees the same learned batch
        label = self.curriculum_labels_one_hot[it]
        noise = torch.cat([self.curriculum[it], label], dim=-1)
        x = self.generator(noise)
        if not x.requires_grad:
            label = label.detach()
        return x, label


class CGTNAllShuffled(CGTN):
    def forward(self, it):
        # pool all learned noise vectors and labels, then draw a random batch
        all_images = torch.reshape(self.curriculum, (-1,) + self.curriculum.shape[2:])
        all_labels = torch.reshape(self.curriculum_labels_one_hot, (-1,) + self.curriculum_labels_one_hot.shape[2:])
        # note: randint samples with replacement
        idx = torch.randint(len(all_images), size=(self.generator_batch_size,), device=all_images.device)
        noise = all_images[idx]
        labels = all_labels[idx]
        noise = torch.cat([noise, labels], dim=-1)
        x = self.generator(noise)
        return x, labels


class CGTNBatchShuffled(CGTN):
    def forward(self, it):
        # pick one of the learned batches at random, ignoring `it`
        idx = torch.randint(len(self.curriculum), size=())
        noise = self.curriculum[idx]
        labels = self.curriculum_labels_one_hot[idx]
        noise = torch.cat([noise, labels], dim=-1)
        x = self.generator(noise)
        return x, labels
```
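The four variants differ only in how the inner loop picks the inputs fed to the generator at each step. This stdlib sketch isolates that choice, assuming a stand-in "bank" of learned noise vectors (labels and the generator itself are omitted for brevity). The function names are hypothetical. Unlike the `randint` call in the repo's all-shuffled class, `all_shuffled_epoch` follows the "without replacement" description above.

```python
import random


def make_bank(num_batches, batch_size, noise_size, rng):
    """Stand-in for the learned noise tensor of shape (num_batches, batch_size, noise_size)."""
    return [[[rng.gauss(0.0, 1.0) for _ in range(noise_size)] for _ in range(batch_size)]
            for _ in range(num_batches)]


def no_curriculum(batch_size, noise_size, rng):
    """Fresh Gaussian noise every step: nothing about the inputs is stored."""
    return [[rng.gauss(0.0, 1.0) for _ in range(noise_size)] for _ in range(batch_size)]


def full_curriculum(it, bank):
    """Step `it` always receives the same learned batch: content and order are controlled."""
    return bank[it]


def shuffled_batch(bank, rng):
    """A randomly chosen learned batch: batch content is controlled, order is not."""
    return bank[rng.randrange(len(bank))]


def all_shuffled_epoch(bank, rng):
    """Pool every learned input and re-batch them in a random order, without replacement."""
    batch_size = len(bank[0])
    pool = [item for batch in bank for item in batch]
    rng.shuffle(pool)
    return [pool[i:i + batch_size] for i in range(0, len(pool), batch_size)]


rng = random.Random(0)
bank = make_bank(num_batches=4, batch_size=8, noise_size=3, rng=rng)  # the paper uses 32 x 128
epoch = all_shuffled_epoch(bank, rng)
```

Moving down the list, each variant gives the meta-learner one more thing to control: which inputs exist at all, which inputs share a batch, and finally which batch arrives at which step.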
Claim 3: Faster NN training times for CIFAR-10, MNIST, and CartPole
- GTN for supervised learning
- Experiment setup:
  - Data: MNIST images or GTN-generated MNIST images
  - Three setups were compared:
    - Real data: learners trained on randomly sampled real data
    - Dataset Distillation: learners trained on synthetic data whose training examples are encoded directly as tensors optimized by the meta-objective (Wang et al. 2019b)
    - GTN: the authors' method
- GTN outperforms both baselines on MNIST
- The authors note that the goal is not state-of-the-art performance but how well and how cheaply a learner can be trained. This matters for NAS, where it makes it possible to identify and evaluate candidate architectures quickly
- Interestingly, the images generated by the GTN often do not look like recognizable digits.[2]
```python
import torch
import torch.nn as nn


# Setup 3: GTN (fresh Gaussian noise and random labels at every inner-loop step)
class GTN(nn.Module):
    def __init__(self, generator, generator_batch_size, noise_size):
        super().__init__()
        self.generator = generator
        self.generator_batch_size = generator_batch_size
        self.noise_size = noise_size
        self.generator = torch.jit.trace(self.generator, (torch.rand(generator_batch_size, noise_size + 10),))

    def forward(self, it):
        # assumes a CUDA device is available
        curriculum_labels = torch.randint(10, size=(self.generator_batch_size,), dtype=torch.int64, device="cuda")
        curriculum_labels_one_hot = torch.zeros(self.generator_batch_size, 10, device="cuda")
        curriculum_labels_one_hot.scatter_(1, curriculum_labels.unsqueeze(-1), 1)
        curriculum_labels_one_hot = curriculum_labels_one_hot.to("cuda")
        noise = torch.cat([torch.randn(self.generator_batch_size, self.noise_size, device="cuda"),
                           curriculum_labels_one_hot], dim=-1)
        x = self.generator(noise)
        return x, curriculum_labels_one_hot


# Setup 2: Dataset Distillation (the training images themselves are learned tensors)
class DatasetDistillation(nn.Module):
    def __init__(self, num_inner_iterations, generator_batch_size, img_shape):
        super().__init__()
        self.curriculum_labels = nn.Parameter(
            torch.randint(10, size=(num_inner_iterations, generator_batch_size), dtype=torch.int64),
            requires_grad=False)
        self.curriculum_labels_one_hot = torch.zeros(num_inner_iterations, generator_batch_size, 10)
        self.curriculum_labels_one_hot.scatter_(2, self.curriculum_labels.unsqueeze(-1), 1)
        self.curriculum_labels_one_hot = nn.Parameter(self.curriculum_labels_one_hot, requires_grad=False)
        self.curriculum = nn.Parameter(
            torch.randn((num_inner_iterations, generator_batch_size,) + img_shape, dtype=torch.float32))

    def forward(self, it):
        x = self.curriculum[it]
        # learned pixels squashed into (-2, 2)
        return torch.tanh(x) * 2, self.curriculum_labels_one_hot[it]
```
Claim 4: Faster NAS with limited compute
- The authors claim that GTNs are useful for NAS because they reduce the computational cost of evaluating candidate architectures
- In other words, the GTN is judged by how well it serves as a cheap proxy for evaluating candidate architectures
- The authors measured this by computing the Spearman rank correlation between architecture performance after 128 steps of training on GTN synthetic data and after 100 epochs of training on real data
- Correlation: 0.3606 overall, or 0.5582 when considering only the top 50% of architectures
- Experiment Setup:
- GTN-NAS improves the NAS state of the art by finding higher-performing architectures than comparable methods such as weight sharing and Graph HyperNetworks.[4]
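Spearman's rank correlation, the metric used above, is the Pearson correlation of the two rank vectors, with tied values given their average rank. Below is a stdlib sketch with hypothetical function names. The "top 50%" helper assumes the subset is the best half of architectures by real-data performance, which these notes do not spell out.

```python
import math


def average_ranks(values):
    """1-based ranks; tied values share the average of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def pearson(xs, ys):
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)


def spearman(proxy_scores, real_scores):
    """Rank correlation between cheap proxy scores and expensive real-data scores."""
    return pearson(average_ranks(proxy_scores), average_ranks(real_scores))


def spearman_top_fraction(proxy_scores, real_scores, fraction=0.5):
    """Spearman restricted to the best `fraction` of candidates by real-data score (assumption)."""
    k = max(2, int(len(real_scores) * fraction))
    keep = sorted(range(len(real_scores)), key=lambda i: real_scores[i], reverse=True)[:k]
    return spearman([proxy_scores[i] for i in keep], [real_scores[i] for i in keep])
```

A correlation of 1 would mean the 128-step GTN proxy orders architectures exactly as full training does. For NAS, the ordering is what matters, not the absolute accuracies, which is why a rank correlation is the natural metric here.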
```python
# NAS setup used in the paper
import numpy as np
import torch
import torch.nn as nn

# sample_model, Classifier, Learner, models, inner_optimizers and tlogger
# are defined elsewhere in the GTN codebase


class AutoML(nn.Module):
    def __init__(self, generator, optimizers, initial_batch_norm_momentum=0.9):
        super().__init__()
        self.generator = generator
        self.optimizers = torch.nn.ModuleList(optimizers)
        # batch-norm momentum is learned through a logit so it stays in (0, 1)
        self.batch_norm_momentum_logit = nn.Parameter(
            torch.as_tensor(inner_optimizers.inv_sigmoid(initial_batch_norm_momentum)))

    @property
    def batch_norm_momentum(self):
        return torch.sigmoid(self.batch_norm_momentum_logit)

    def sample_learner(self, input_shape, device, allow_nas=False, learner_type="base",
                       iteration_maps_seed=False, iteration=None, deterministic=False,
                       iterations_depth_schedule=100, randomize_width=False):
        if iteration_maps_seed:
            iteration = iteration - 1
            encoding = [iteration % 6, iteration // 6]
        else:
            encoding = None

        if learner_type == "sampled":
            # sampled architecture whose depth grows with `iteration` (capped at 4 layers)
            layers = min(4, max(0, iteration // iterations_depth_schedule))
            model, encoding = sample_model(input_shape, layers=layers, encoding=encoding, blocks=2,
                                           seed=iteration if deterministic else None, batch_norm_momentum=0)
            tlogger.record_tabular("encoding", encoding)
        elif learner_type == "sampled4":
            model, encoding = sample_model(input_shape, layers=4, encoding=encoding,
                                           seed=iteration if deterministic else None, batch_norm_momentum=0)
            tlogger.record_tabular("encoding", encoding)
        elif learner_type == "base":
            model = Classifier(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_fc":
            model = Classifier(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width,
                               use_global_pooling=False)
        elif learner_type == "linear":
            model = models.LinearClassifier(input_shape)
        elif learner_type == "base_larger":
            model = models.ClassifierLarger(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_larger2":
            model = models.ClassifierLarger2(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_larger3":
            model = models.ClassifierLarger3(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width)
        elif learner_type == "base_larger3_global_pooling":
            model = models.ClassifierLarger3(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width,
                                             use_global_pooling=True)
        elif learner_type == "base_larger4_global_pooling":
            model = models.ClassifierLarger4(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width,
                                             use_global_pooling=True)
        elif learner_type == "base_larger4":
            model = models.ClassifierLarger4(input_shape, batch_norm_momentum=0.0, randomize_width=randomize_width,
                                             use_global_pooling=False)
        else:
            raise NotImplementedError()
        # each learner gets a randomly chosen inner-loop optimizer
        return Learner(model=model.to(device), optimizer=np.random.choice(self.optimizers)), encoding
```
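The reason a cheap proxy helps NAS can be sketched independently of the repo: score every candidate cheaply (in GTN-NAS, roughly, a short training run on GTN data), then spend the expensive real-data evaluation only on a shortlist. This is a minimal sketch, not the paper's search procedure; `proxy_score` and `full_score` are hypothetical callables.

```python
def proxy_nas(candidates, proxy_score, full_score, top_k=3):
    """Rank every candidate with a cheap proxy, then fully evaluate only the top_k."""
    shortlist = sorted(candidates, key=proxy_score, reverse=True)[:top_k]
    best = max(shortlist, key=full_score)
    return best, shortlist


# Toy usage: the "true" quality of architecture a is a; the proxy is a noisy view of it.
full_calls = []


def expensive_eval(arch):
    full_calls.append(arch)  # record how many expensive evaluations were spent
    return arch


best, shortlist = proxy_nas(range(10), lambda arch: arch % 7, expensive_eval, top_k=3)
```

Only `top_k` candidates ever reach the expensive evaluation, so the total cost is dominated by the proxy. The better the proxy's rank correlation with real performance, the more likely the true best architecture makes the shortlist; here the noisy proxy misses architectures 7 to 9.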
- Weight normalization can be used to stabilize meta-training in many different environments
- Possible extensions of GTNs:
  - Unsupervised learning, to create a useful embedding function
  - Stabilizing GAN training and preventing mode collapse
  - A closed-loop curriculum that adapts to the learner's performance
    - i.e. early on the data would be easy, becoming progressively harder as the learner improves. This is analogous to how a tutor teaches a student: they don't jump straight to calculus; they teach algebra first.
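The paper only proposes a closed-loop curriculum as future work. One hypothetical way to close the loop is a simple feedback rule that sets a difficulty level from the learner's recent accuracy, which a generator could then be conditioned on. The thresholds, the integer difficulty scale, and the conditioning itself are all assumptions for illustration.

```python
def next_difficulty(level, recent_accuracy, promote_at=0.9, demote_at=0.5, max_level=10):
    """Hypothetical closed-loop rule: harder data when the learner does well, easier when it struggles."""
    if recent_accuracy >= promote_at:
        return min(level + 1, max_level)
    if recent_accuracy < demote_at:
        return max(level - 1, 0)
    return level


# Toy run with made-up accuracy readings from a learner
level = 0
history = []
for acc in [0.95, 0.92, 0.70, 0.40, 0.91]:
    level = next_difficulty(level, acc)
    history.append(level)
```

The gap between `demote_at` and `promote_at` acts as a dead band, so the level does not oscillate on every small fluctuation in accuracy.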
- The generated data does not need to resemble real training data to achieve good results; the paper's appendix contains many examples.
- GTNs can bootstrap learners in complex ways that combine many aspects of learning: by controlling the data the learner ingests, they control the learning itself.
they also one day could be a key to creating AI-generating algorithms, which seek to bootstrap themselves from simple initial conditions to powerful forms of AI by creating an open-ended stream of challenges (learning opportunities) while learning to solve them (Clune, 2019).
- Bootstrapping learners from simple initial conditions
- Creating NNs "on demand"
Some notes on how GTN is implemented. The code is in the authors' GitHub repository, uber-research/GTN.
- Language: Python/C++
- DL Framework: PyTorch
- Parallelism: Horovod/MPI
Another notable point about the implementation is its use of the Fire library, which automatically generates a CLI for Python programs.[5]
- Fire was made by Google…surprise
Given how many problems in AI today stem from bad data (ill-formed, biased, incomplete, etc.), I find the GTN architecture very interesting. In particular, the prospect of theoretically infinite data generation for RL tasks intrigues me.
From the paper:
GTNs are exciting because they can encode a rich set of possible environments with minimal assumptions, ranging from labeled data for supervised learning to (in theory) entire complex virtual RL domains (with their own learned internal physics)
Appendix K addresses my biggest question about this paper: how does the learner perform when the generator produces unrealistic images?
Many of the images generated by GTNs are unrecognizable (e.g. as digits)
They then go on to say:
Thus, a network can learn to get over 98% accuracy on MNIST training only on unrecognizable images.
The authors seem undecided about why this works: they list several possible "camps" of explanation.
[1] Weight normalization code: https://github.com/uber-research/GTN/blob/19799828d4ddd6b4e1fe837e38bb9ce4d3e588ce/models.py#L461
[2] I think this is one of the most interesting parts of the paper, discussed more later.
[5] I use Fire at work all the time now.