-
Notifications
You must be signed in to change notification settings - Fork 309
/
Copy pathsample_controlNet.py
129 lines (101 loc) · 4.35 KB
/
sample_controlNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
from diffusers.utils import load_image
import os,sys
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.controlnet import ControlNetModel
from diffusers import AutoencoderKL
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import EulerDiscreteScheduler
from PIL import Image
import numpy as np
import cv2
from annotator.midas import MidasDetector
from annotator.dwpose import DWposeDetector
from annotator.util import resize_image,HWC3
# Repository root: the parent of the directory containing this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(_script_dir)
def process_canny_condition(image, canny_threods=(100, 200)):
    """Build a Canny edge-map conditioning image for the ControlNet.

    Args:
        image: image array passed straight to ``cv2.Canny`` (presumably an
            (H, W, 3) or (H, W) uint8 array -- TODO confirm against callers).
        canny_threods: ``(low, high)`` hysteresis thresholds for ``cv2.Canny``.
            Any indexable pair works. The default is an immutable tuple so it
            cannot be shared and mutated across calls. The misspelt name is
            kept so keyword callers keep working.

    Returns:
        PIL.Image.Image built from the edge map replicated to three channels.
    """
    low, high = canny_threods[0], canny_threods[1]
    # cv2 needs a contiguous buffer. This copies only when the input is a
    # non-contiguous view, instead of copying unconditionally.
    src = np.ascontiguousarray(image)
    edges = cv2.Canny(src, low, high)
    # Canny yields a single-channel map; stack it into three identical channels.
    edges = edges[:, :, None]
    edges = np.concatenate([edges, edges, edges], axis=2)
    edges = HWC3(edges)
    return Image.fromarray(edges)
# Lazily created MidasDetector, kept at module level so it is constructed at
# most once per process and reused by later calls.
model_midas = None


def process_depth_condition_midas(img, res = 1024):
    """Build a depth conditioning image with ``MidasDetector``.

    Args:
        img: image array of shape (H, W) or (H, W, C). It is presumably uint8
            RGB; TODO confirm against callers.
        res: resolution passed to ``resize_image`` before running the detector.

    Returns:
        PIL.Image.Image with the detector output, resized back to the input's
        original height and width.
    """
    global model_midas
    # Only height/width are needed. Slicing (instead of unpacking three values)
    # avoids a crash on 2-D grayscale input, which HWC3 is then expected to
    # expand to channels.
    h, w = img.shape[:2]
    img = resize_image(HWC3(img), res)
    if model_midas is None:
        model_midas = MidasDetector()
    result = HWC3(model_midas(img))
    # Restore the caller's spatial size; cv2.resize takes (width, height).
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)
# Lazily created DWposeDetector, kept at module level so it is constructed at
# most once per process and reused by later calls.
model_dwpose = None


def process_dwpose_condition(image, res=1024):
    """Build a pose conditioning image with ``DWposeDetector``.

    Args:
        image: image array of shape (H, W) or (H, W, C). It is presumably uint8
            RGB; TODO confirm against callers.
        res: resolution passed to ``resize_image`` before running the detector.

    Returns:
        PIL.Image.Image with the detector's rendered output, resized back to
        the input's original height and width.
    """
    global model_dwpose
    # Only height/width are needed; this also tolerates 2-D grayscale input.
    h, w = image.shape[:2]
    img = resize_image(HWC3(image), res)
    if model_dwpose is None:
        model_dwpose = DWposeDetector()
    # Run the detector on the resized, channel-normalised image so that
    # ``res`` actually takes effect. The first element of the detector's
    # returned pair is not used here.
    _, pose_img = model_dwpose(img)
    result = HWC3(pose_img)
    # Restore the caller's spatial size; cv2.resize takes (width, height).
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)
# ControlNet variants this script can build a conditioning image for. Each one
# also selects the checkpoint directory ``weights/Kolors-ControlNet-<type>``.
_SUPPORTED_MODEL_TYPES = ('Canny', 'Depth', 'Pose')


def _load_pipeline(model_type):
    """Load the Kolors base model plus the ``model_type`` ControlNet, offload-ready."""
    ckpt_dir = f'{root_dir}/weights/Kolors'
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
    control_path = f'{root_dir}/weights/Kolors-ControlNet-{model_type}'
    controlnet = ControlNetModel.from_pretrained(control_path, revision=None).half()
    pipe = StableDiffusionXLControlNetImg2ImgPipeline(
        vae=vae,
        controlnet=controlnet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False,
    )
    # Deliberately no ``pipe.to("cuda")`` here. Model CPU offload moves each
    # sub-model to the GPU only while it runs. Pushing the whole pipeline to
    # the GPU first defeats that and can run out of memory on small GPUs.
    pipe.enable_model_cpu_offload()
    return pipe


def _build_condition_image(model_type, image_np, max_size):
    """Return the PIL conditioning image produced by the ``model_type`` preprocessor."""
    if model_type == 'Canny':
        return process_canny_condition(image_np)
    if model_type == 'Depth':
        return process_depth_condition_midas(image_np, max_size)
    if model_type == 'Pose':
        return process_dwpose_condition(image_np, max_size)
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected one of {_SUPPORTED_MODEL_TYPES}")


def infer(image_path, prompt, model_type='Canny'):
    """Run Kolors ControlNet img2img on one image and save the results.

    Args:
        image_path: path of the input image.
        prompt: text prompt for generation.
        model_type: one of ``'Canny'``, ``'Depth'`` or ``'Pose'``. It selects
            both the conditioning preprocessor and the ControlNet checkpoint
            under ``{root_dir}/weights``.

    Side effects:
        Writes ``{model_type}_{basename}_condition.jpg`` and
        ``{model_type}_{basename}.jpg`` into ``{root_dir}/controlnet/outputs``,
        creating that directory if needed.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    # Validate before any model weights are loaded, so a typo fails immediately.
    if model_type not in _SUPPORTED_MODEL_TYPES:
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {_SUPPORTED_MODEL_TYPES}")

    # English: "nsfw, facial shadows, low resolution, JPEG artifacts, blurry,
    # bad, dark face, neon lights".
    negative_prompt = 'nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯'
    MAX_IMG_SIZE = 1024
    controlnet_conditioning_scale = 0.7
    control_guidance_end = 0.9
    strength = 1.0

    basename = os.path.splitext(os.path.basename(image_path))[0]
    # Close the file handle promptly. Converting to RGB keeps RGBA, palette and
    # grayscale inputs from reaching the preprocessors and the pipeline with a
    # non-3-channel layout.
    with Image.open(image_path) as src:
        init_image = src.convert('RGB')
    init_image = resize_image(init_image, MAX_IMG_SIZE)
    condi_img = _build_condition_image(model_type, np.array(init_image), MAX_IMG_SIZE)

    pipe = _load_pipeline(model_type)
    # Fixed seed so repeated runs with the same inputs are reproducible.
    generator = torch.Generator(device="cpu").manual_seed(66)
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=50,
        guidance_scale=6.0,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]

    out_dir = f'{root_dir}/controlnet/outputs'
    os.makedirs(out_dir, exist_ok=True)
    condi_img.save(f'{out_dir}/{model_type}_{basename}_condition.jpg')
    image.save(f'{out_dir}/{model_type}_{basename}.jpg')
if __name__ == '__main__':
    # Expose ``infer`` as a command-line interface through python-fire.
    # Arguments map to its parameters, e.g.
    #   <script> IMAGE_PATH "PROMPT" --model_type=Depth
    from fire import Fire
    Fire(infer)