Path: blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
4044 views
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""


class StableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # Outdated scheduler configs: warn, then patch the in-memory config so the run uses `steps_offset=1`.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Same treatment for `clip_sample=True`: warn and force it to False in the in-memory config.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # NOTE: this must be an f-string, otherwise `{self.__class__}` is emitted literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Old UNet checkpoints (saved before diffusers 0.9.0) may carry a too-small `sample_size`; patch it to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Ratio between pixel-space and latent-space resolution, derived from the number of VAE blocks.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the offloaded models execute on.

        Raises:
            ImportError: If `accelerate` is not installed or is older than `0.14.0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the offloaded models execute on.

        Raises:
            ImportError: If `accelerate` is not installed or is older than `0.17.0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Chain the hooks in execution order; each hook is handed the previous one via `prev_module_hook`.
        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If neither `negative_prompt` nor
                `negative_prompt_embeds` is given, an empty prompt is used. Ignored when
                `do_classifier_free_guidance` is `False`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.

        Returns:
            `torch.FloatTensor`: The prompt embeddings repeated `num_images_per_prompt` times. With classifier free
            guidance the negative embeddings are concatenated in front of them along the batch dimension.

        Raises:
            TypeError: If both `prompt` and `negative_prompt` are given but have different types.
            ValueError: If `negative_prompt` is a list whose length differs from the batch size.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize again without truncation only to report which part of the prompt was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                # Only comparable when a text prompt was passed; with `prompt_embeds` the prompt is None
                # and any `str` / `list` negative prompt is acceptable.
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # Repeat so the embeddings match `batch_size` even when it was taken from `prompt_embeds`
                # (for a `str` prompt `batch_size` is 1, so this is a single-element list).
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Pad the negative prompt to the same sequence length as the positive embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def run_safety_checker(self, image, device, dtype):
        r"""
        Run the safety checker on decoded images, if one is configured.

        Args:
            image: Decoded images; converted with `self.numpy_to_pil` before feature extraction.
            device: Device the feature extractor output is moved to.
            dtype: Dtype the extracted pixel values are cast to before calling the safety checker.

        Returns:
            `tuple`: `(image, has_nsfw_concept)` as returned by the safety checker, or the unchanged `image` and
            `None` when `self.safety_checker` is `None`.
        """
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        r"""
        Decode latents with the VAE into a float32 numpy image batch with values clamped to `[0, 1]`.

        The latents are divided by `self.vae.config.scaling_factor`, decoded, shifted with `image / 2 + 0.5`,
        clamped, moved to CPU and permuted with `(0, 2, 3, 1)`.
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        r"""
        Build the keyword arguments for `self.scheduler.step` that this scheduler's signature accepts.

        Returns:
            `dict`: Contains `eta` and/or `generator` only if `self.scheduler.step` declares those parameters.
        """
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_step_kwargs = {}
        if "eta" in step_parameters:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        if "generator" in step_parameters:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        r"""
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If `height` or `width` is not divisible by 8; if `callback_steps` is not a positive integer;
                if both or neither of `prompt` / `prompt_embeds` are given; if `prompt` is neither `str` nor `list`;
                if both `negative_prompt` and `negative_prompt_embeds` are given; or if `prompt_embeds` and
                `negative_prompt_embeds` have different shapes.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` fails the isinstance check, so it is rejected here as well.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        r"""
        Create (or move to `device`) the initial latents and scale them by `self.scheduler.init_noise_sigma`.

        When `latents` is `None`, noise of shape
        `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is sampled with
        `randn_tensor`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If neither `negative_prompt` nor
                `negative_prompt_embeds` is given, an empty prompt is used. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is not greater than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
                final latents are returned as-is, without VAE decoding or safety checking.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            if output_type == "pil":
                image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)