Learning a Sketch Tensor Space for Image Inpainting of Man-made Scenes (ICCV 2021)
Overview
We learn an encoder-decoder model that encodes a Sketch Tensor (ST) space consisting of refined lines and edges. The model then recovers the masked images from the ST space.
News
- Released the inference code.
- Released the training code.
This work has since been further improved in ZITS (CVPR 2022).
Preparation
- Prepare the environment.
- Download the pretrained masked wireframe detection model LSM-HAWP (retrained from HAWP, CVPR 2020).
- Download the weights for your use case into the ‘check_points’ folder: P2M (Man-made Places2), P2C (Comprehensive Places2), or shanghaitech (ShanghaiTech, all man-made scenes).
- For training, we provide irregular and segmentation masks (download) at different masking rates. Define the mask file list before training (see flist_example.txt).
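The exact format of flist_example.txt is not described here. A common convention in inpainting codebases is a plain text file with one image or mask path per line; assuming that convention, a minimal sketch for generating such a list could look like this (the `write_flist` helper and the extension set are hypothetical, so compare the output against flist_example.txt before using it):

```python
from pathlib import Path

# Assumed image/mask extensions; adjust to match your data.
EXTS = {".png", ".jpg", ".jpeg"}


def write_flist(root, out_file, exts=EXTS):
    """Write the sorted absolute paths of all matching files under `root`, one per line.

    Assumes the one-path-per-line convention (hypothetical helper).
    """
    paths = sorted(
        str(p.resolve())
        for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in exts
    )
    with open(out_file, "w") as f:
        f.write("\n".join(paths) + ("\n" if paths else ""))
    return paths
```

For example, `write_flist("masks/irregular", "masks_train.txt")` would list every mask image under that directory; sorting keeps the list reproducible across runs.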
Training
Because the training code was rewritten, it differs from the test code in a few ways:
Training uses src/models.py, while testing uses src/model_inference.py.
Images are scaled to [-1, 1] for training and [0, 1] for testing.
Masks are always concatenated to the inputs.
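The range difference and the mask concatenation can be illustrated with NumPy. This is a sketch, not the code in src/models.py: it assumes a channel-first (C x H x W) layout as used by PyTorch, and a binary mask with 1 marking holes and 0 marking valid pixels. The helper names are hypothetical.

```python
import numpy as np


def to_train_range(img01):
    """Map an image from [0, 1] (test convention) to [-1, 1] (training convention)."""
    return img01 * 2.0 - 1.0


def to_test_range(img_pm1):
    """Map an image from [-1, 1] back to [0, 1]."""
    return (img_pm1 + 1.0) / 2.0


def concat_mask(img_chw, mask_hw):
    """Append a single-channel mask to a C x H x W image, giving (C+1) x H x W.

    Assumes mask_hw is H x W with 1 = hole, 0 = valid (hypothetical convention).
    """
    mask = mask_hw.astype(img_chw.dtype)[None, ...]  # H x W -> 1 x H x W
    return np.concatenate([img_chw, mask], axis=0)
```

Mixing up the two ranges is an easy mistake when reusing preprocessing between training and testing, so converting explicitly at the boundary helps keep the two code paths consistent.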
- Generate wireframes with LSM-HAWP:
CUDA_VISIBLE_DEVICES=0 python lsm_hawp_inference.py --ckpt_path <best_lsm_hawp.pth> --input_path <input image path> --output_path <output image path>
- Set the file lists in training_configs/config_MST.yml (example: flist_example.txt).
- Train the inpainting model in two stages, stage 1 followed by stage 2:
python train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0
python train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0
For DDP training with multiple GPUs:
python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage1.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
python -m torch.distributed.launch --nproc_per_node=4 train_MST_stage2.py --path <model_name> --config training_configs/config_MST.yml --gpu 0,1,2,3
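If you want to script the steps above, a small wrapper can build the same commands and run them in order. This is a hypothetical sketch: it uses only the flags shown above, assumes the file lists in training_configs/config_MST.yml already point at your data and the generated wireframes, and assumes both stages use the same <model_name>. It launches the interpreter via `sys.executable` instead of a bare `python`.

```python
import os
import subprocess
import sys

CONFIG = "training_configs/config_MST.yml"


def build_wireframe_cmd(ckpt_path, input_path, output_path):
    """Command for lsm_hawp_inference.py, using only the documented flags."""
    return [sys.executable, "lsm_hawp_inference.py",
            "--ckpt_path", str(ckpt_path),
            "--input_path", str(input_path),
            "--output_path", str(output_path)]


def build_stage_cmd(stage, model_name, gpus, config=CONFIG):
    """Command for train_MST_stage{stage}.py; uses torch.distributed.launch for >1 GPU."""
    args = [f"train_MST_stage{stage}.py", "--path", model_name,
            "--config", config, "--gpu", ",".join(str(g) for g in gpus)]
    if len(gpus) > 1:
        # Same launcher and process count as the DDP commands above.
        return [sys.executable, "-m", "torch.distributed.launch",
                f"--nproc_per_node={len(gpus)}"] + args
    return [sys.executable] + args


def run_pipeline(ckpt_path, image_dir, wireframe_dir, model_name, gpus=(0,)):
    """Hypothetical end-to-end run: wireframes, then stage 1, then stage 2."""
    gpus = list(gpus)
    # Equivalent to the CUDA_VISIBLE_DEVICES=<gpu> shell prefix.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpus[0]))
    subprocess.run(build_wireframe_cmd(ckpt_path, image_dir, wireframe_dir),
                   env=env, check=True)
    # check=True raises CalledProcessError, so stage 2 only runs if stage 1 succeeds.
    for stage in (1, 2):
        subprocess.run(build_stage_cmd(stage, model_name, gpus), check=True)
```

Failing fast between stages avoids starting stage 2 from a stage 1 run that crashed partway through.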
Test for a single image
python test_single.py --gpu_id 0 \
--PATH ./check_points/MST_P2C \
--image_path <your image path> \
--mask_path <your mask path (0 means valid and 255 means masked)>
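Test masks follow the convention above: 0 marks valid pixels and 255 marks pixels to inpaint. A minimal sketch for building such a mask with a rectangular hole is shown below; the helper names are hypothetical, and saving assumes Pillow is installed.

```python
import numpy as np


def make_rect_mask(height, width, top, left, box_h, box_w):
    """Single-channel uint8 mask: 0 = valid pixel, 255 = masked (to be inpainted)."""
    mask = np.zeros((height, width), dtype=np.uint8)
    mask[top:top + box_h, left:left + box_w] = 255
    return mask


def save_mask(mask, path):
    """Save the mask as a grayscale image (requires Pillow)."""
    from PIL import Image
    Image.fromarray(mask).save(path)
```

For example, `save_mask(make_rect_mask(256, 256, 64, 64, 128, 128), "mask.png")` writes a 256x256 mask with a centered 128x128 hole that can be passed as --mask_path. The mask should have the same height and width as the input image.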
Object Removal Examples
Comparisons