diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3dbd9287bd9..0631cc5dee54 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -783,6 +783,17 @@ def bar(prg): txt_dir=args.output_dir + "/text_encoder_trained" if os.path.exists(txt_dir): subprocess.call('rm -r '+txt_dir, shell=True) + else: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + ) + frz_dir=args.output_dir + "/text_encoder_frozen" + pipeline.save_pretrained(args.output_dir) + if args.train_text_encoder and os.path.exists(frz_dir): + subprocess.call('mv -f '+frz_dir +'/*.* '+ args.output_dir+'/text_encoder', shell=True) + subprocess.call('rm -r '+ frz_dir, shell=True) if os.path.exists(args.captions_dir+'off'): subprocess.call('mv '+args.captions_dir+'off '+args.captions_dir, shell=True)