diff --git a/CTAI_model/net/train.py b/CTAI_model/net/train.py index 0bfa334..d45d1b3 100644 --- a/CTAI_model/net/train.py +++ b/CTAI_model/net/train.py @@ -43,7 +43,7 @@ optimizer = torch.optim.Adam(unet.parameters(), learn_rate) def train(): global res - dataloaders = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0) + dataloaders = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0) for epoch in range(epochs): dt_size = len(dataloaders.dataset) epoch_loss, epoch_dice = 0, 0