
I have a neural network with a fixed architecture (let's call it Architecture A). I also have two datasets, Dataset 1 and Dataset 2, both of which are independently and identically distributed (i.i.d.). I’m exploring how training strategies affect performance under the following scenarios:

  1. Model 1 (Sequential Training):

    • I train the neural network on Dataset 1 until it converges to its global optimum (assume sufficient epochs).
    • I save the trained model, then continue training it on Dataset 2 for a large number of epochs until it again reaches the global optimum.
  2. Model 2 (Combined Training):

    • I combine Dataset 1 and Dataset 2 into a single dataset.
    • I train a fresh instance of the same neural network (Architecture A) on this combined dataset for a sufficient number of epochs until it reaches the global optimum.
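
For concreteness, here is a minimal sketch of the two regimes in PyTorch-style pseudocode (`make_model`, `train_until_converged`, `dataset1`, and `dataset2` are hypothetical placeholders standing in for Architecture A, a train-to-convergence loop, and the two datasets):

```python
# Sketch only: make_model and train_until_converged are placeholders.
from torch.utils.data import ConcatDataset, DataLoader

# Model 1: sequential training on Dataset 1, then Dataset 2.
model1 = make_model()  # fresh instance of Architecture A
train_until_converged(model1, DataLoader(dataset1, batch_size=64, shuffle=True))
train_until_converged(model1, DataLoader(dataset2, batch_size=64, shuffle=True))

# Model 2: a fresh instance trained once on the combined data.
model2 = make_model()
combined = ConcatDataset([dataset1, dataset2])
train_until_converged(model2, DataLoader(combined, batch_size=64, shuffle=True))
```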

For this scenario, I’m assuming:

  • In general, model accuracy is strictly monotonically increasing in the size of the training data (i.e., more data leads to better performance).
  • The optimization process is ideal, and both models successfully reach their respective global optima.

Questions:

  1. Will Model 2 (trained on the combined dataset) generally perform better than Model 1 (sequentially trained)? Why or why not?
  2. What factors (e.g., overfitting, catastrophic forgetting, or data distribution) might influence the relative performance of these two approaches?

I’d like to understand the theoretical reasoning and practical implications behind this. Any insights or references to relevant concepts (e.g., transfer learning, optimization dynamics) would be greatly appreciated!

4 Answers


Well, it is a fundamental question of statistics.

When you take a sample, it could have come from any distribution that satisfies its functional constraints.

For example, you have the points (1, 1), (2, 2), and (3, 3). Quite a large number of functions pass through these points. Now take three more points: (4, 4), (5, 5), and (6, 6). Again, quite a large number of functions fit these as well. But when you combine them, the number of possible distributions satisfying all six points together shrinks. Hence our approximation of the true distribution improves.
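
A quick numerical sketch of this shrinking of possibilities (NumPy; the degree-5 polynomial family, i.e. 6 free coefficients, is an arbitrary illustrative choice):

```python
# With 3 points, many degree-5 polynomials fit; with 6, the fit is unique.
import numpy as np

x3, y3 = np.array([1., 2., 3.]), np.array([1., 2., 3.])
x6, y6 = np.arange(1., 7.), np.arange(1., 7.)

for x, y in [(x3, y3), (x6, y6)]:
    A = np.vander(x, 6)  # columns x**5, x**4, ..., x**0
    _, _, rank, _ = np.linalg.lstsq(A, y, rcond=None)
    print(f"{len(x)} points: rank {rank} of 6 -> "
          f"{6 - rank} coefficient direction(s) left unconstrained")
```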

So for Q1, I would go with Model 2: it has a relatively large number of data points, so the functional approximation done by the ANN will be a better one. For the first case, the functional approximation achieved by training on Dataset 1 alone gives no guarantee of a good fit for Dataset 2. And once you start training on Dataset 2, the network might forget the distribution it had adjusted to for Dataset 1.

For Q2: Model 2 has a better grip on the data, so it is less prone to overfitting and catastrophic forgetting, and less sensitive to shifts in the data distribution.

Aviral Verma

One important factor is the data distribution and its quantity.

If the two datasets come from the same distribution, it is better to train the model with all the data: the more data a deep learning model can use, the better it will learn and the less it will overfit.

But if the second dataset comes from another distribution, and especially if it has far fewer samples than the first, then it is better to train on the first dataset first and only afterwards train on the second.

This is what we do when, for example, we start with a general CNN model originally trained on a large number of images from the internet and then fine-tune it on a smaller dataset of medical images (e.g., MRI or X-rays) to get good classification scores on this special kind of image.

If we simply mixed both datasets here, the medical images would be "lost" among the general ones, and the scores on the medical images would not be very good.

And so that the model doesn't forget what it learned from the first dataset, one can add some images from the first dataset to the second dataset for the second round of training, or fine-tune (update) only the final layers and freeze the others to keep the general knowledge learned during the first training, as in the sketch below.
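
A concrete sketch of the freeze-and-fine-tune option (PyTorch/torchvision; `num_medical_classes` is a hypothetical placeholder for your label count):

```python
# Keep the general knowledge from the first training by freezing it.
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(weights="IMAGENET1K_V1")  # "first dataset": generic images

for param in model.parameters():           # freeze everything learned so far
    param.requires_grad = False

# Replace and train only the final layer on the medical images.
model.fc = nn.Linear(model.fc.in_features, num_medical_classes)
```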

rehaqds

Intuition Behind Model Training in General

When training a model on a dataset, it minimizes the loss specific to that dataset. That means it optimizes for the patterns and structures it finds (in the current dataset only) by adapting the trainable parameters (a.k.a. weights), starting either from scratch (usually random values) or from the weights of an already trained model. It does not care what it 'learned' before; it simply takes that 'knowledge' as a starting point and keeps only what is helpful for reducing the loss. This means it will happily reuse parameters that previously detected other patterns to detect whatever new patterns reduce the loss on the new dataset. Note that it effectively does not 'know' which parameters are already being used to remember old patterns; it only observes the loss (the difference between the generated output and the desired output from the dataset).

A typical model operates in a high-dimensional parameter space, but to visualize the process, let’s simplify: Assume the model only has two parameters, X and Y — like a 2D coordinate system. Each dataset defines a unique loss surface (like hills and valleys in this space). For any given position (X, Y), the height (Z-axis) represents the loss.

Training is the process of moving across this surface, guided by the loss gradient, to find a local or global minimum (i.e., the best parameter values that reduce error). Typically, the model will follow a relatively straight path towards the nearest optimum.

What Happens When Mixing Datasets?

If we mix Dataset 1 and Dataset 2, assuming equal-probability sampling (e.g., both datasets are the same size), the combined loss becomes: loss_combined(x, y) = 0.5 * loss1(x, y) + 0.5 * loss2(x, y). The model is now minimizing neither loss1 nor loss2 alone, but their average. The new minimum need not lie midway between the two individual minima; it could, for example, sit at a point where the loss surfaces of Dataset 1 and Dataset 2 both happen to have a local minimum.
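
A one-dimensional sketch of why the combined minimum need not sit midway between the two individual minima (NumPy; the two loss shapes are arbitrary toy choices):

```python
# The sharper valley pulls the combined minimum toward itself.
import numpy as np

x = np.linspace(-1, 5, 10001)
loss1 = 10.0 * (x - 1.0) ** 2          # sharp valley at x = 1
loss2 = 0.1 * (x - 3.0) ** 2           # shallow valley at x = 3
loss_combined = 0.5 * loss1 + 0.5 * loss2

print("loss1 minimum at    ", x[loss1.argmin()])          # ~1.0
print("loss2 minimum at    ", x[loss2.argmin()])          # ~3.0
print("combined minimum at ", x[loss_combined.argmin()])  # ~1.02, not the midpoint 2.0
```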

Note that the loss surface optimized during training is always the average loss over randomly sampled entries from the training dataset, evaluated at points (the model's parameter values) in the multi-dimensional parameter space. It is only an estimate of the true loss function, which we usually cannot know: computing it would require all possible inputs and outputs. In practice we can only ever train on a subset of the (often effectively infinite) dataset our inference data is drawn from, due to limited compute power and time.

The smaller the dataset, the less accurate this estimate becomes, resulting in a "noisier" and more irregular loss surface.

Conversely, the larger the dataset, the smoother and more representative the estimated loss surface is, more closely reflecting the true underlying loss landscape during training.
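
This effect is easy to observe in a toy setting: for a 1-parameter least-squares model, the minimizer of the empirical loss surface scatters widely when the sample is small and settles down as it grows (NumPy sketch; the linear model and noise level are arbitrary choices):

```python
# Smaller samples -> noisier empirical loss surface -> noisier minimizer.
import numpy as np

rng = np.random.default_rng(0)

def empirical_minimizer(n, true_w=2.0):
    x = rng.normal(size=n)
    y = true_w * x + rng.normal(size=n)  # noisy targets
    return (x @ y) / (x @ x)             # closed-form argmin of mean (w*x - y)^2

for n in [10, 100, 10000]:
    ws = [empirical_minimizer(n) for _ in range(200)]
    print(f"n={n:6d}: std of the empirical minimizer = {np.std(ws):.4f}")
```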

Why Sequential Finetuning Can Fail

Suppose we first train on Dataset 1 and reach the optimum at (ox1, oy1). Then, we finetune on Dataset 2 and the new optimum is at (ox2, oy2). But the true optimum of the combined loss is at (ox3, oy3).

The problem: Unless (ox3, oy3) lies directly on the path between the first two optima, it's unlikely the model will reach that point during sequential training. Even early stopping based on a validation set won’t necessarily steer it towards the combined optimum.

This is why finetuning often includes data from the original dataset — to keep the model generalizing over both.
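
The failure mode can be reproduced with plain gradient descent on two toy quadratic losses (NumPy sketch; the loss shapes are arbitrary stand-ins for the two datasets):

```python
# Sequential descent ends at Dataset 2's optimum; combined descent compromises.
import numpy as np

def grad1(p):  # gradient of loss1, sharp minimum at (0, 0)
    return 10.0 * p

def grad2(p):  # gradient of loss2, shallow minimum at (4, 0)
    return 0.2 * (p - np.array([4.0, 0.0]))

def descend(p, grad, steps=2000, lr=0.05):
    for _ in range(steps):
        p = p - lr * grad(p)
    return p

p = descend(np.array([2.0, 3.0]), grad1)   # phase 1: lands at (0, 0)
p = descend(p, grad2)                      # phase 2: lands at (4, 0), loss1 forgotten
print("sequential endpoint:", p)

q = descend(np.array([2.0, 3.0]), lambda q: 0.5 * grad1(q) + 0.5 * grad2(q))
print("combined endpoint:  ", q)           # ~(0.08, 0): a genuine compromise
```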

Answers to Your Questions

Q1: How does Model 2 compare to Model 1?

  • Model 2, trained on the combined dataset, will generally perform better on data resembling the combined distribution.

  • Model 1, finetuned only on Dataset 2, might perform better on Dataset 2-like data but worse overall.

  • There's also a risk that Model 1 gets stuck in a local minimum of Dataset 1, making it suboptimal for Dataset 2 — especially if the learning rate is too low.

  • Since training is stochastic, results can vary. Unless the loss function has a dominant global minimum, we often settle for good-enough local optima.

Q2: What about Catastrophic Forgetting and Overfitting?

Catastrophic Forgetting

Yes, this happens, especially when Dataset 2 is very different from Dataset 1.

Example: Train a language model on Language A, then continue training on Language B — the model will likely "forget" Language A. The more different the datasets are, the stronger this effect.
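
A tiny, deliberately extreme reproduction of this effect (PyTorch; the two toy tasks are hypothetical constructs whose labels directly conflict):

```python
# Train on task A, then on a conflicting task B; watch A's accuracy collapse.
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_task(mean0, mean1, n=500):
    # 1-D binary classification: class 0 ~ N(mean0, 1), class 1 ~ N(mean1, 1).
    x = torch.cat([torch.randn(n, 1) + mean0, torch.randn(n, 1) + mean1])
    y = torch.cat([torch.zeros(n), torch.ones(n)]).long()
    return x, y

def train(model, x, y, epochs=300):
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()

def acc(model, x, y):
    with torch.no_grad():
        return (model(x).argmax(dim=1) == y).float().mean().item()

model = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 2))
xa, ya = make_task(-2.0, 2.0)  # "Language A"
xb, yb = make_task(2.0, -2.0)  # "Language B": the opposite arrangement

train(model, xa, ya)
print("task A accuracy after training on A:", acc(model, xa, ya))  # ~1.0
train(model, xb, yb)
print("task A accuracy after training on B:", acc(model, xa, ya))  # ~0.0
```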

Overfitting

Related, but different. If you monitor validation loss and stop training when it rises, overfitting is less likely. Overfitting is mostly a risk when the dataset is small relative to model size. Training on the combined dataset helps mitigate this, as it provides more diverse data.

On Data Distribution Shift

Differences in dataset composition can pull the model in different directions.

Example:

Dataset 1: 80% cats, 10% dogs, 10% birds

Dataset 2: 90% dogs, 10% cats, 0% birds

The model might learn shortcuts (e.g., mostly guessing "cat" on Dataset 1 and "dog" on Dataset 2) without really learning image features. This is often the most efficient change to the model parameters for achieving a lower average loss. Only after this adaptation does the model refine its estimate based on image features, and that takes much longer, since backpropagation first adapts the last layers of the network, which have a much larger influence on the loss. When birds disappear in Dataset 2, the model may "forget" that birds exist, i.e., the average probability it assigns to birds will quickly adapt towards 0% before the model adapts to Dataset 2's image features. This is catastrophic forgetting in action.
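
The shortcut can be made precise: under cross-entropy, the best *constant* prediction is exactly the class-frequency vector, so changing the label mix immediately changes the model's cheapest first move (NumPy sketch using the class mixes above):

```python
# The loss-minimizing constant prediction equals the label frequencies.
import numpy as np

def constant_ce(pred, priors):
    # Cross-entropy of always predicting `pred` on data with label mix `priors`.
    pred, priors = np.asarray(pred, float), np.asarray(priors, float)
    mask = priors > 0
    return -np.sum(priors[mask] * np.log(pred[mask]))

d1 = np.array([0.80, 0.10, 0.10])  # cats, dogs, birds in Dataset 1
d2 = np.array([0.10, 0.90, 0.00])  # cats, dogs, birds in Dataset 2
uniform = np.full(3, 1 / 3)

print(constant_ce(d1, d1), "<", constant_ce(uniform, d1))  # leaning "cat" wins
print(constant_ce(d2, d2), "<", constant_ce(uniform, d2))  # leaning "dog" wins
# Note the optimal bird probability under Dataset 2 is exactly 0 -- the
# quick "forgetting" of birds described above.
```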

Final Thoughts

  • Always try to train last on the data most similar to inference data.
  • Training updates the model by following gradients (directions for parameter updates).
  • Validation just ensures we don’t overtrain and start memorizing instead of generalizing.
SDwarfs

Generally speaking, it is a bad idea to train until convergence to a global optimum (i.e., until the training loss is minimal). That tends to cause overfitting. Instead, it is typically better to use early stopping, which halts training earlier and acts as a form of regularization. So it is better to fix a total amount of training and compare two models that have both seen about the same amount of training.
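
A minimal early-stopping loop for reference (PyTorch-flavoured sketch; `max_epochs`, `train_epoch`, `validation_loss`, and the loaders are hypothetical placeholders for your own training code):

```python
# Stop once validation loss stops improving; keep the best checkpoint.
import copy

best_val, best_state, patience, bad_epochs = float("inf"), None, 5, 0
for epoch in range(max_epochs):
    train_epoch(model, train_loader)           # one pass over the training data
    val = validation_loss(model, val_loader)   # loss on held-out data
    if val < best_val:                         # still improving
        best_val, bad_epochs = val, 0
        best_state = copy.deepcopy(model.state_dict())
    else:
        bad_epochs += 1
        if bad_epochs >= patience:             # no improvement for `patience` epochs
            break
model.load_state_dict(best_state)              # roll back to the best point
```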

In any case, I predict that Model 2 may perform better, or both may perform similarly (but Model 1 will not perform better).

We can consider a more extreme version of your question. Suppose you have $n$ data points $x_1,\dots,x_n$. Consider model A, which is obtained by 10 epochs of training, where in each epoch you train on all $n$ points; vs model B, which is obtained by training on $x_1$ 10 times, then $x_2$ 10 times, etc. Generally, model A is likely to perform better than B (or they might perform about the same). Or consider model A', which is obtained by 1 epoch of training, where you train on each of the $n$ points once; vs model B', which is obtained by training on $x_1$ 10 times, then $x_2$ 10 times, etc., ending by training on $x_{n/10}$ 10 times. Generally, model A' is likely to perform better than model B' (or about the same). And notice that A' is what you get by stopping the training of model A 1/10 of the way through, and B' is what you get by stopping the training of model B 1/10 of the way through.
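
Here is a toy numerical version of that comparison (NumPy SGD on a 1-parameter least-squares model; the data, learning rate, and repeat counts are arbitrary choices):

```python
# Model A sees shuffled epochs; model B grinds on one point at a time.
import numpy as np

rng = np.random.default_rng(1)
n, lr = 20, 0.05
x = rng.uniform(1.0, 3.0, size=n)
y = 2.0 * x + rng.normal(size=n)   # noisy targets, true slope 2

def sgd(order, w=0.0):
    for i in order:                # one step on the loss (w*x_i - y_i)^2
        w -= lr * 2.0 * x[i] * (w * x[i] - y[i])
    return w

order_a = np.concatenate([rng.permutation(n) for _ in range(10)])  # 10 epochs
order_b = np.repeat(np.arange(n), 10)                  # each point 10x in a row

for name, order in [("A", order_a), ("B", order_b)]:
    w = sgd(order)
    print(f"model {name}: w = {w:.3f}, "
          f"full-data MSE = {np.mean((w * x - y) ** 2):.3f}")
```

Model B's endpoint is dominated by whichever points it saw last, a miniature of the forgetting discussed in the other answers.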

D.W.