
Can dice and bce loss work on a multi-class task? #9

Open
Dr-Cube opened this issue Dec 30, 2019 · 5 comments

Dr-Cube commented Dec 30, 2019

Thanks for the great implementation code.

I am confused about the loss function. As far as I can see, dice and BCE are both used for binary-class tasks. Can they work well on a multi-class task? From your code I can see the losses work OK, but what about a bigger dataset?

I tried F.cross_entropy(), but it gives me this: RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [36, 4, 224, 224]. Could you please tell me what's wrong? Thanks.

import torch
import torch.nn.functional as F

def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target)
    target_long = target.type(torch.LongTensor)
    ce = F.cross_entropy(pred, target_long.cuda())  # this line raises the RuntimeError above

    # pred = F.sigmoid(pred)  # F.sigmoid is deprecated in favor of torch.sigmoid
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)  # dice_loss from this repo

    # weighted combination of the two losses
    loss = bce * bce_weight + dice * (1 - bce_weight)

    # accumulate per-sample totals; divided by the sample count when printing
    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss
usuyama (Owner) commented May 19, 2020

You're right. The current code uses BCE, so each pixel can have multiple classes, i.e. multi-label.

To make it a single class per pixel, i.e. multi-class, you can use CE. I think you need to reshape/view to 2D.
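
For reference, a minimal sketch of that multi-class variant, assuming the target is one-hot encoded as [N, C, H, W] (as in the error above); multiclass_ce_loss is a hypothetical name, and note that F.cross_entropy also accepts spatial targets of shape [N, H, W] directly, so the explicit reshape to 2D is optional:

import torch
import torch.nn.functional as F

def multiclass_ce_loss(pred, target_onehot):
    # pred: raw logits [N, C, H, W]; target_onehot: one-hot mask [N, C, H, W]
    target_indices = target_onehot.argmax(dim=1)  # collapse channels to class indices [N, H, W]
    # cross_entropy applies log_softmax internally, so pred should be raw logits
    return F.cross_entropy(pred, target_indices)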

ckolluru commented May 24, 2020

Why are the metrics multiplied by the batch size, added cumulatively and then divided by the total number of samples during printing?

It seems like this will print a scaled version of the average metric value (from print_metrics), depending on the batch size. Please correct me if I'm wrong.

sarmientoj24 commented

Hi, I am having a problem dealing with a multi-class task where the dimensions are as follows:

MASK TARGET: torch.Size([4, 1, 600, 900, 3])
OUTPUT: torch.Size([4, 5, 600, 900])

@ckolluru Have you created your loss function for multiclass already?
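
One way to reconcile those shapes, as a sketch: if the trailing 3 is an RGB color encoding of the classes (an assumption here, as is the PALETTE mapping), the mask can be converted to [N, H, W] class indices first:

import torch

# assumed color-to-class mapping; the real palette depends on the dataset
PALETTE = {
    (0, 0, 0): 0,
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
    (255, 255, 0): 4,
}

def rgb_mask_to_indices(mask):
    # mask: RGB mask of shape [N, 1, H, W, 3] -> class-index mask [N, H, W]
    mask = mask.squeeze(1).long()  # [N, H, W, 3]
    indices = torch.zeros(mask.shape[:3], dtype=torch.long)
    for color, cls in PALETTE.items():
        indices[(mask == torch.tensor(color)).all(dim=-1)] = cls
    return indices  # usable as a cross_entropy target against the [4, 5, 600, 900] logits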

shahzad-ali commented

Why are the metrics multiplied by the batch size, added cumulatively and then divided by the total number of samples during printing?

The reason is that some batches (e.g. the last batch) may have fewer training examples than the others. Accumulating metric * batch_size and dividing by total_samples gives a better estimate of the metric over a complete epoch. Skimming through the Training a Classifier example from the PyTorch tutorials shows that the same strategy is used in section "4. Train the network", where the statistics are printed.

It seems like this will print a scaled version of the average metric value (from print_metrics), depending on the batch size. Please correct me if I'm wrong.

That's true! You may get a better insight into this topic by reading How is the loss for an epoch reported?
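
To make the bookkeeping concrete, a sketch of an epoch loop around the calc_loss above (model, loader, and device are assumptions here):

from collections import defaultdict

def run_epoch(model, loader, device):
    metrics = defaultdict(float)
    epoch_samples = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        pred = model(inputs)
        loss = calc_loss(pred, targets, metrics)  # accumulates metric * batch_size
        epoch_samples += targets.size(0)
        # backward pass / optimizer step omitted
    # dividing by the sample count (not the batch count) keeps the average
    # correct even when the last batch is smaller
    for name, value in metrics.items():
        print(name, value / epoch_samples)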

AlphonsG commented

If you want to use cross-entropy, make sure you're not applying the sigmoid function beforehand.
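
A short illustration of why, as a sketch: F.cross_entropy applies log_softmax to its input internally, so it expects raw logits, and squashing them through a sigmoid first distorts the loss:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 4, 4)          # [N, C, H, W] raw network output
target = torch.randint(0, 5, (2, 4, 4))   # [N, H, W] class indices

loss_ok = F.cross_entropy(logits, target)                  # correct: pass raw logits
loss_bad = F.cross_entropy(torch.sigmoid(logits), target)  # logits squashed twice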
