PyTorch implementation of DPViT (accepted at NeurIPS'23)
Mitigating the Effect of Incidental Correlations on Part-based Learning
Gaurav Bhatt*, Deepayan Das, Leonid Sigal, Vineeth N Balasubramanian
The code has been tested with the following environment:
git clone https://github.com/GauravBh1010tt/DPViT.git
cd DPViT
conda env create --name dpvit --file=environment.yml
source activate dpvit
We extract the weakly-supervised masks using the RemBG package. Since RemBG is not optimized for batch inference, we modified its code; the modified version is in DPViT/data_utils/rembg.
cd data_utils
python rembg/nbg_replace.py --dataset=miniIM --img_dir='path_to_train_folder' --batch_size=20
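The `--batch_size=20` flag groups images so that masks are extracted in batches rather than one at a time. The sketch below is a minimal, hypothetical illustration of that grouping step: `batched_paths` and the accepted file extensions are assumptions for illustration, not functions or settings from `nbg_replace.py`.

```python
from pathlib import Path
from typing import Iterator, List, Tuple

def batched_paths(img_dir: str, batch_size: int = 20,
                  exts: Tuple[str, ...] = (".jpg", ".jpeg", ".png")) -> Iterator[List[Path]]:
    """Hypothetical helper: yield lists of at most `batch_size` image paths.

    Images are collected recursively under `img_dir` (e.g. the train folder)
    and sorted so that batches are reproducible across runs.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    # Keep only files with an image extension (directories have no suffix).
    paths = sorted(p for p in Path(img_dir).rglob("*") if p.suffix.lower() in exts)
    for start in range(0, len(paths), batch_size):
        yield paths[start:start + batch_size]
```

For example, `for batch in batched_paths("path_to_train_folder", 20): ...` would hand 20 paths at a time to a batched mask model; the last batch may be smaller.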
Each image and its corresponding mask are saved as a single image. This speeds up data loading on SLURM clusters by minimizing the number of file I/O calls.
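The README does not specify how the image and mask are combined into one file. As a sketch only, assuming a side-by-side layout (image on the left, mask replicated to three channels on the right), packing and unpacking could look like this; `pack_image_and_mask` and `unpack_image_and_mask` are hypothetical names, and the actual on-disk layout used by DPViT may differ.

```python
import numpy as np

def pack_image_and_mask(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical layout: HxWx3 image and HxW mask side by side -> Hx(2W)x3."""
    if image.shape[:2] != mask.shape:
        raise ValueError("image and mask must share height and width")
    # Replicate the single-channel mask to 3 channels so both halves match.
    mask_rgb = np.repeat(mask[:, :, None], 3, axis=2).astype(image.dtype)
    return np.concatenate([image, mask_rgb], axis=1)

def unpack_image_and_mask(packed: np.ndarray):
    """Inverse of pack_image_and_mask: split the width in half."""
    width = packed.shape[1] // 2
    image = packed[:, :width]
    mask = packed[:, width:, 0]  # any channel of the right half holds the mask
    return image, mask
```

One read then yields both the image and its mask, which is the I/O saving described above.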
Use mini-imagenet-tools to create the miniImageNet dataset. Note that all datasets should follow an ImageNet-style folder layout, like this:
|-- miniimagenet
| |-- train
| | |-- n908761
| | |-- n453897
| | |-- ...
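A quick sanity check for this layout is to map each class folder to an integer label. The sketch below sorts folder names, which mirrors the convention used by `torchvision.datasets.ImageFolder`; whether DPViT's loaders assign labels exactly this way is an assumption, and `class_to_index` is a hypothetical helper.

```python
from pathlib import Path
from typing import Dict

def class_to_index(split_dir: str) -> Dict[str, int]:
    """Map each class sub-folder (e.g. n908761) to an integer label.

    Only directories count as classes; stray files in the split folder are ignored.
    """
    classes = sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())
    if not classes:
        raise FileNotFoundError(f"no class folders found in {split_dir}")
    return {name: idx for idx, name in enumerate(classes)}
```

For example, `class_to_index("miniimagenet/train")` should return one entry per class folder; an empty result or an unexpected count usually points to a misplaced directory level.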
Download the ImageNet-9 dataset from ImageNet-9. The file structure looks like this:
|-- in9
| |-- train
| | |-- 00_dog
| | |-- 01_bird
| | |-- ...
| |-- bg_challenge
| | |-- mixed_same
| | | |-- 00_dog
| | | |-- 01_bird
| | | |-- ...
| | |-- mixed_rand
| | |-- ...
bash scripts/run_local.sh
Update the hyperparameters in the run_local.sh file.
bash scripts/run_slurm.sh
You need to update the cluster-specific configuration in the run_slurm.sh file.
We train DPViT on 4 A100 GPUs with 40 GB of VRAM each. Set the batch size according to your hardware specifications.
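When moving to a different number of GPUs, one option is to keep the global batch size fixed and change the per-GPU share. The sketch below assumes data-parallel training in which the batch size in the scripts is per GPU, which is common but not confirmed here (check run_local.sh / run_slurm.sh); `per_gpu_batch_size` and the example global batch of 256 are hypothetical.

```python
def per_gpu_batch_size(global_batch_size: int, num_gpus: int) -> int:
    """Split a global batch evenly across GPUs (assumes data-parallel training)."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    if global_batch_size % num_gpus != 0:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by {num_gpus} GPUs")
    return global_batch_size // num_gpus
```

With a hypothetical global batch of 256, this gives 64 per GPU on 4 GPUs and 128 per GPU on 2 GPUs; if 128 samples do not fit in memory, the global batch has to shrink.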
By default, visualizations are saved inside the experiment folder: <exp_name>/visualization_epoch<#>. Inference can be run on the images inside the img_viz folder using the following command:
python eval/eval_dpvit.py --ckp_path="path to saved checkpoint" --eval=0 --viz=1 --image_path='img_viz'
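To inspect the most recent training-time visualizations, you can pick the `visualization_epoch<#>` folder with the highest epoch. This is a minimal sketch based only on the folder naming described above; `latest_visualization_dir` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

# Matches folder names such as "visualization_epoch12" and captures the epoch.
_EPOCH_DIR = re.compile(r"^visualization_epoch(\d+)$")

def latest_visualization_dir(exp_dir: str) -> Optional[Path]:
    """Return the visualization_epoch<#> folder with the highest epoch, or None."""
    best_epoch, best_dir = -1, None
    for p in Path(exp_dir).iterdir():
        m = _EPOCH_DIR.match(p.name)
        if p.is_dir() and m and int(m.group(1)) > best_epoch:
            best_epoch, best_dir = int(m.group(1)), p
    return best_dir
```

Epochs are compared as integers on purpose: a plain string sort would rank `visualization_epoch9` above `visualization_epoch10`.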
python eval/eval_dpvit.py --ckp_path="path to saved checkpoint" --eval=1 --num_shots=5
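The `--num_shots=5` flag indicates K-shot evaluation. For readers unfamiliar with the protocol, the sketch below shows how a standard N-way K-shot episode is typically sampled; `sample_episode`, `n_way=5`, and `n_query=15` are assumptions for illustration and are not taken from `eval_dpvit.py`, which may sample episodes differently.

```python
import random
from typing import Dict, List, Sequence, Tuple

def sample_episode(labels_to_items: Dict[str, Sequence[str]], n_way: int = 5,
                   k_shot: int = 5, n_query: int = 15, seed: int = 0
                   ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    """Sample one N-way K-shot episode as (label, item) pairs.

    Support and query items are drawn without replacement from the same
    classes, so no item appears in both sets.
    """
    rng = random.Random(seed)
    # Only classes with enough samples for both support and query qualify.
    eligible = [c for c, items in labels_to_items.items()
                if len(items) >= k_shot + n_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query samples")
    classes = rng.sample(sorted(eligible), n_way)
    support, query = [], []
    for c in classes:
        picked = rng.sample(list(labels_to_items[c]), k_shot + n_query)
        support += [(c, x) for x in picked[:k_shot]]
        query += [(c, x) for x in picked[k_shot:]]
    return support, query
```

Accuracy is then averaged over many such episodes, each with freshly sampled classes and images.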
python eval/imagenet_cls.py --pretrained_weights="saved model" --data_path "data/in9" --partition "bg_challenge/mixed_same/val" --num_classes 9
Choose one of the following ImageNet-9 partitions: mixed_same, mixed_rand, ...
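To evaluate several partitions in one go, you can generate one command per partition using the same flags as the command above. This is a sketch: `in9_eval_commands` is a hypothetical helper, it only lists the two partitions named explicitly, and it builds the commands without running them.

```python
import shlex
from typing import List

def in9_eval_commands(weights: str, partitions: List[str],
                      data_path: str = "data/in9") -> List[str]:
    """Build one imagenet_cls.py command per ImageNet-9 partition."""
    cmds = []
    for part in partitions:
        args = [
            "python", "eval/imagenet_cls.py",
            "--pretrained_weights", weights,
            "--data_path", data_path,
            "--partition", f"bg_challenge/{part}/val",
            "--num_classes", "9",
        ]
        # shlex.join quotes arguments containing spaces, e.g. "saved model".
        cmds.append(shlex.join(args))
    return cmds

for cmd in in9_eval_commands("saved_model.pth", ["mixed_same", "mixed_rand"]):
    print(cmd)
```

Each printed line can be pasted into a shell or passed to `subprocess.run(shlex.split(cmd))`. The `--flag value` form used here is equivalent to `--flag=value` for argparse-based scripts.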
If you find this repo useful, please cite:
@inproceedings{bhatt2023mitigating,
title={Mitigating the Effect of Incidental Correlations on Part-based Learning},
author={Bhatt, Gaurav and Das, Deepayan and Sigal, Leonid and Balasubramanian, Vineeth N},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}