
I'm working on a binary classification task where the goal is to determine whether a tissue contains malignant cells. Each instance in my dataset consists of:

  • a microscope image of the cell
  • a small set of tabular metadata including an identifier of the imaging session and a binary feature indicating whether the cell was treated with fluorescent particles or not

I'm considering a hybrid neural network that combines a CNN to extract features from the image with either a TabNet model or a fully connected MLP to process the tabular data. My idea is to concatenate the features from both branches and pass them to a shared classification head. My questions:

  1. How should I handle the identifier? Should I embed it or drop it completely?
  2. Are there alternative ways to model the tabular branch beyond an MLP or TabNet, especially with very few tabular features?
  3. Are there any best practices for combining CNN image embeddings with tabular data?

Thanks in advance for any suggestions or shared experiences.



Image encoder (e.g., a CNN): I suggest trying a convolutional neural network (such as ResNet, EfficientNet, or MobileNet) to extract features from the microscope image. Input: the image of the cell. Output: a feature vector. You can start with a model pretrained on ImageNet and fine-tune it on your data.

Hadeynike

How should I handle the identifier? Should I embed it or drop it completely?

If it's a random ID for that experiment, or something that is not meaningful for the task at hand, it's better to drop it. At best it will have no impact; at worst it will add noise to the features or introduce a risk of overfitting to spurious patterns.

If you think it carries information that could be useful for the task, then I would extract those details and encode them as input features.

An example for this task might be timestamp information embedded in the ID: perhaps it correlates with strong temperature variations during the day, which affect equipment calibration. In that case, it might help to extract the AM/PM information from the ID and use it as an input feature.
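A minimal sketch of extracting such a feature, assuming a purely hypothetical ID format `<prefix>_<YYYYMMDD>T<HHMM>` (the real IDs in your dataset may look nothing like this, so the parsing would need adapting):

```python
from datetime import datetime


def is_afternoon(session_id: str) -> int:
    """Return 1 if the session started at or after noon, else 0.

    Assumes a hypothetical ID format "<prefix>_<YYYYMMDD>T<HHMM>",
    e.g. "scope3_20230501T1432".
    """
    timestamp = session_id.rsplit("_", 1)[1]
    started = datetime.strptime(timestamp, "%Y%m%dT%H%M")
    return int(started.hour >= 12)


print(is_afternoon("scope3_20230501T1432"))  # 1
print(is_afternoon("scope3_20230501T0915"))  # 0
```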

Consider whether the ID could inadvertently leak target information, that is, whether it hints at the outcome you are trying to predict. It could be direct leakage, like an ID that encodes the sample's diagnosis. Indirect leakage could be when a specific imaging session only contained malignant samples: the model will use the ID as a proxy for the answer, resulting in an overly optimistic estimate of model performance.
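One rough way to look for that kind of proxy is to compute the positive-label rate per session: sessions sitting at 0 or 1 deserve a closer look. Splitting train and test by session rather than by row also keeps the evaluation from rewarding session memorization (scikit-learn's `GroupKFold` applies the same idea to cross-validation). A minimal pure-Python sketch; the function names and toy data are hypothetical:

```python
from collections import defaultdict


def label_rate_by_session(session_ids, labels):
    """Fraction of positive (malignant) labels per session."""
    counts = defaultdict(lambda: [0, 0])  # session -> [positives, total]
    for sid, y in zip(session_ids, labels):
        counts[sid][0] += y
        counts[sid][1] += 1
    return {sid: pos / total for sid, (pos, total) in counts.items()}


def split_by_session(session_ids, test_sessions):
    """Return train/test row indices so that no session appears in both."""
    test_sessions = set(test_sessions)
    train_idx = [i for i, sid in enumerate(session_ids) if sid not in test_sessions]
    test_idx = [i for i, sid in enumerate(session_ids) if sid in test_sessions]
    return train_idx, test_idx


sessions = ["s1", "s1", "s2", "s2", "s3", "s3"]
labels = [1, 1, 0, 0, 1, 0]
print(label_rate_by_session(sessions, labels))  # {'s1': 1.0, 's2': 0.0, 's3': 0.5}
print(split_by_session(sessions, ["s3"]))  # ([0, 1, 2, 3], [4, 5])
```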

Are there alternative ways to model the tabular branch beyond MLP or TabNet, especially with very few tabular features?

My preference is to start simply, by encoding the columns one at a time. Consider each column in turn and what useful information it carries, if any. Then encode each column with a method that plays to the strengths of that column's data.
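As an illustration of per-column encoding for the two columns in the question (assuming you keep the session ID at all), here is a minimal sketch: the binary flag becomes a 0/1 float, and the categorical ID becomes an integer index with 0 reserved for values unseen during training. The column names `treated` and `session_id` and the helper functions are hypothetical.

```python
def build_vocab(values):
    """Map each distinct category to an integer index; 0 is reserved for unseen values."""
    vocab = {}
    for v in values:
        if v not in vocab:
            vocab[v] = len(vocab) + 1
    return vocab


def encode_row(row, session_vocab):
    """Encode one metadata row column by column (hypothetical column names)."""
    return {
        # binary column: keep as a 0/1 float
        "treated": float(bool(row["treated"])),
        # categorical column: integer index, 0 for sessions not seen in training
        "session_idx": session_vocab.get(row["session_id"], 0),
    }


train_rows = [
    {"session_id": "s1", "treated": True},
    {"session_id": "s2", "treated": False},
]
vocab = build_vocab(r["session_id"] for r in train_rows)
print(encode_row({"session_id": "s2", "treated": True}, vocab))
# {'treated': 1.0, 'session_idx': 2}
print(encode_row({"session_id": "s9", "treated": False}, vocab))
# {'treated': 0.0, 'session_idx': 0}
```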

If you encode a column as a binary feature, you could insert it at the classifier layer along with the image embedding.
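A minimal PyTorch sketch of that late fusion, assuming a 512-dimensional image embedding (e.g., from a ResNet-18 backbone) and a single 0/1 treatment flag; the class name and hidden size are hypothetical choices:

```python
import torch
from torch import nn


class LateFusionHead(nn.Module):
    """Hypothetical classifier over [image embedding, binary metadata flag]."""

    def __init__(self, image_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(image_dim + 1, hidden_dim),  # +1 for the treated flag
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # single logit for malignant vs. benign
        )

    def forward(self, image_emb: torch.Tensor, treated: torch.Tensor) -> torch.Tensor:
        # treated has shape (batch,) with values 0.0/1.0; make it a column
        x = torch.cat([image_emb, treated.unsqueeze(1)], dim=1)
        return self.classifier(x).squeeze(1)


head = LateFusionHead(image_dim=512)
logits = head(torch.randn(4, 512), torch.tensor([0.0, 1.0, 1.0, 0.0]))
print(logits.shape)  # torch.Size([4])
```

Training would pair the single logit with `nn.BCEWithLogitsLoss`.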

For columns encoded as integer indices (for example, categorical or ordinal codes), you could define an embedding layer (torch.nn.Embedding) that maps each index to an N-dimensional vector, to be combined with the image embedding.
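A minimal sketch of that idea, assuming the integer indices were built with 0 reserved for unseen categories. `padding_idx=0` keeps that row as a fixed all-zero vector during training; the class name and `emb_dim=8` are hypothetical choices. Note that an embedding treats indices as unordered categories, so it does not exploit any ordering an ordinal column may have.

```python
import torch
from torch import nn


class EmbeddingFusionHead(nn.Module):
    """Hypothetical classifier over [image embedding, learned embedding of an integer-coded column]."""

    def __init__(self, image_dim: int, num_categories: int, emb_dim: int = 8):
        super().__init__()
        # num_categories + 1 rows; index 0 ("unseen") maps to a fixed all-zero vector
        self.embedding = nn.Embedding(num_categories + 1, emb_dim, padding_idx=0)
        self.classifier = nn.Linear(image_dim + emb_dim, 1)

    def forward(self, image_emb: torch.Tensor, category_idx: torch.Tensor) -> torch.Tensor:
        # category_idx: LongTensor of shape (batch,) -> (batch, emb_dim)
        cat_emb = self.embedding(category_idx)
        return self.classifier(torch.cat([image_emb, cat_emb], dim=1)).squeeze(1)


head = EmbeddingFusionHead(image_dim=512, num_categories=10)
logits = head(torch.randn(3, 512), torch.tensor([0, 4, 10]))
print(logits.shape)  # torch.Size([3])
```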