I have a fairly unusual multi-label, multi-class problem. My neural network outputs 6 logits, and I encode the target as a binary number, so the number of classes I am trying to predict is 2^6 = 64. The reason for this encoding is that making the last layer a torch.nn.Linear layer with 64 neurons makes the model too big. The dataset is also very unbalanced: some labels are much more frequent than others. I have a weight torch.Tensor of size 64 that I pass to the weight argument of torch.nn.functional.cross_entropy, but I get an error:
RuntimeError: cross_entropy: weight tensor should be defined either for all 6 classes or no classes but got weight tensor of shape: [64]
How do I assign a weight to each of the 2^6 = 64 possible outputs?
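For reference, here is a minimal sketch of the setup that reproduces the error (the tensor shapes match my description; the random data is just illustrative):

```python
import torch
import torch.nn.functional as F

batch = 8
logits = torch.randn(batch, 6)                      # network output: 6 logits per sample
targets = torch.randint(0, 2, (batch, 6)).float()   # class index binary-encoded as 6 bits
class_weights = torch.rand(64)                      # one weight per 64-way class

try:
    F.cross_entropy(logits, targets, weight=class_weights)
except RuntimeError as e:
    print(e)  # weight tensor should be defined either for all 6 classes or no classes
```

Passing a weight of shape [6] instead would silence the error, but that only weights the 6 bit positions, not the 64 encoded classes, which is not what I want.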