Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
shivamshrirao
GitHub Repository: shivamshrirao/diffusers
Path: blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
4044 views
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
# http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
import inspect
16
from typing import Any, Callable, Dict, List, Optional, Union
17
18
import torch
19
from packaging import version
20
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
21
22
from ...configuration_utils import FrozenDict
23
from ...models import AutoencoderKL, UNet2DConditionModel
24
from ...schedulers import KarrasDiffusionSchedulers
25
from ...utils import (
26
deprecate,
27
is_accelerate_available,
28
is_accelerate_version,
29
logging,
30
randn_tensor,
31
replace_example_docstring,
32
)
33
from ..pipeline_utils import DiffusionPipeline
34
from . import StableDiffusionPipelineOutput
35
from .safety_checker import StableDiffusionSafetyChecker
36
37
38
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
40
EXAMPLE_DOC_STRING = """
41
Examples:
42
```py
43
>>> import torch
44
>>> from diffusers import StableDiffusionPipeline
45
46
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
47
>>> pipe = pipe.to("cuda")
48
49
>>> prompt = "a photo of an astronaut riding a horse on mars"
50
>>> image = pipe(prompt).images[0]
51
```
52
"""
53
54
55
class StableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    # Components that may be `None` (or omitted entirely) when the pipeline is loaded.
    _optional_components = ["safety_checker", "feature_extractor"]
84
def __init__(
85
self,
86
vae: AutoencoderKL,
87
text_encoder: CLIPTextModel,
88
tokenizer: CLIPTokenizer,
89
unet: UNet2DConditionModel,
90
scheduler: KarrasDiffusionSchedulers,
91
safety_checker: StableDiffusionSafetyChecker,
92
feature_extractor: CLIPImageProcessor,
93
requires_safety_checker: bool = True,
94
):
95
super().__init__()
96
97
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
98
deprecation_message = (
99
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
100
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
101
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
102
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
103
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
104
" file"
105
)
106
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
107
new_config = dict(scheduler.config)
108
new_config["steps_offset"] = 1
109
scheduler._internal_dict = FrozenDict(new_config)
110
111
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
112
deprecation_message = (
113
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
114
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
115
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
116
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
117
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
118
)
119
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
120
new_config = dict(scheduler.config)
121
new_config["clip_sample"] = False
122
scheduler._internal_dict = FrozenDict(new_config)
123
124
if safety_checker is None and requires_safety_checker:
125
logger.warning(
126
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
127
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
128
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
129
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
130
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
131
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
132
)
133
134
if safety_checker is not None and feature_extractor is None:
135
raise ValueError(
136
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
137
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
138
)
139
140
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
141
version.parse(unet.config._diffusers_version).base_version
142
) < version.parse("0.9.0.dev0")
143
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
144
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
145
deprecation_message = (
146
"The configuration file of the unet has set the default `sample_size` to smaller than"
147
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
148
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
149
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
150
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
151
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
152
" in the config might lead to incorrect results in future versions. If you have downloaded this"
153
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
154
" the `unet/config.json` file"
155
)
156
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
157
new_config = dict(unet.config)
158
new_config["sample_size"] = 64
159
unet._internal_dict = FrozenDict(new_config)
160
161
self.register_modules(
162
vae=vae,
163
text_encoder=text_encoder,
164
tokenizer=tokenizer,
165
unet=unet,
166
scheduler=scheduler,
167
safety_checker=safety_checker,
168
feature_extractor=feature_extractor,
169
)
170
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
171
self.register_to_config(requires_safety_checker=requires_safety_checker)
172
173
def enable_vae_slicing(self):
174
r"""
175
Enable sliced VAE decoding.
176
177
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
178
steps. This is useful to save some memory and allow larger batch sizes.
179
"""
180
self.vae.enable_slicing()
181
182
def disable_vae_slicing(self):
183
r"""
184
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
185
computing decoding in one step.
186
"""
187
self.vae.disable_slicing()
188
189
def enable_vae_tiling(self):
190
r"""
191
Enable tiled VAE decoding.
192
193
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
194
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
195
"""
196
self.vae.enable_tiling()
197
198
def disable_vae_tiling(self):
199
r"""
200
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
201
computing decoding in one step.
202
"""
203
self.vae.disable_tiling()
204
205
def enable_sequential_cpu_offload(self, gpu_id=0):
206
r"""
207
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
208
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
209
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
210
Note that offloading happens on a submodule basis. Memory savings are higher than with
211
`enable_model_cpu_offload`, but performance is lower.
212
"""
213
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
214
from accelerate import cpu_offload
215
else:
216
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
217
218
device = torch.device(f"cuda:{gpu_id}")
219
220
if self.device.type != "cpu":
221
self.to("cpu", silence_dtype_warnings=True)
222
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
223
224
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
225
cpu_offload(cpu_offloaded_model, device)
226
227
if self.safety_checker is not None:
228
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
229
230
def enable_model_cpu_offload(self, gpu_id=0):
231
r"""
232
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
233
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
234
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
235
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
236
"""
237
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
238
from accelerate import cpu_offload_with_hook
239
else:
240
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
241
242
device = torch.device(f"cuda:{gpu_id}")
243
244
if self.device.type != "cpu":
245
self.to("cpu", silence_dtype_warnings=True)
246
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
247
248
hook = None
249
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
250
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
251
252
if self.safety_checker is not None:
253
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
254
255
# We'll offload the last model manually.
256
self.final_offload_hook = hook
257
258
@property
259
def _execution_device(self):
260
r"""
261
Returns the device on which the pipeline's models will be executed. After calling
262
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
263
hooks.
264
"""
265
if not hasattr(self.unet, "_hf_hook"):
266
return self.device
267
for module in self.unet.modules():
268
if (
269
hasattr(module, "_hf_hook")
270
and hasattr(module._hf_hook, "execution_device")
271
and module._hf_hook.execution_device is not None
272
):
273
return torch.device(module._hf_hook.execution_device)
274
return self.device
275
276
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        # Derive the batch size from whichever of `prompt` / `prompt_embeds` was given.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize once more without truncation to detect (and warn about) dropped text.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # Only pass an attention mask if the text encoder's config asks for one.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            # The encoder returns a tuple/ModelOutput; index 0 is the last hidden state.
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                # No negative prompt given: use the empty string as the unconditional input.
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Pad the unconditional input to the same sequence length as the prompt embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
414
def run_safety_checker(self, image, device, dtype):
415
if self.safety_checker is not None:
416
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
417
image, has_nsfw_concept = self.safety_checker(
418
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
419
)
420
else:
421
has_nsfw_concept = None
422
return image, has_nsfw_concept
423
424
def decode_latents(self, latents):
425
latents = 1 / self.vae.config.scaling_factor * latents
426
image = self.vae.decode(latents).sample
427
image = (image / 2 + 0.5).clamp(0, 1)
428
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
429
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
430
return image
431
432
def prepare_extra_step_kwargs(self, generator, eta):
433
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
434
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
435
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
436
# and should be between [0, 1]
437
438
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
439
extra_step_kwargs = {}
440
if accepts_eta:
441
extra_step_kwargs["eta"] = eta
442
443
# check if the scheduler accepts generator
444
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
445
if accepts_generator:
446
extra_step_kwargs["generator"] = generator
447
return extra_step_kwargs
448
449
def check_inputs(
450
self,
451
prompt,
452
height,
453
width,
454
callback_steps,
455
negative_prompt=None,
456
prompt_embeds=None,
457
negative_prompt_embeds=None,
458
):
459
if height % 8 != 0 or width % 8 != 0:
460
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
461
462
if (callback_steps is None) or (
463
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
464
):
465
raise ValueError(
466
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
467
f" {type(callback_steps)}."
468
)
469
470
if prompt is not None and prompt_embeds is not None:
471
raise ValueError(
472
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
473
" only forward one of the two."
474
)
475
elif prompt is None and prompt_embeds is None:
476
raise ValueError(
477
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
478
)
479
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
480
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
481
482
if negative_prompt is not None and negative_prompt_embeds is not None:
483
raise ValueError(
484
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
485
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
)
487
488
if prompt_embeds is not None and negative_prompt_embeds is not None:
489
if prompt_embeds.shape != negative_prompt_embeds.shape:
490
raise ValueError(
491
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
492
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
493
f" {negative_prompt_embeds.shape}."
494
)
495
496
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
497
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
498
if isinstance(generator, list) and len(generator) != batch_size:
499
raise ValueError(
500
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
501
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
502
)
503
504
if latents is None:
505
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
506
else:
507
latents = latents.to(device)
508
509
# scale the initial noise by the standard deviation required by the scheduler
510
latents = latents * self.scheduler.init_noise_sigma
511
return latents
512
513
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        # Some schedulers run `order` internal steps per timestep; the extra leading
        # timesteps are "warmup" steps that don't advance the progress bar.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            # Skip decoding entirely; the caller gets the raw latents.
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU (set by `enable_model_cpu_offload`)
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)