import os
import gc
import urllib.request
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM, CLIPTokenizer, CLIPTextModel
from diffusers import DiffusionPipeline
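# GlueNet checkpoints released with Salesforce's GlueGen: one CLIP-space translator per
# language (Chinese, French, Italian, Japanese, Spanish) plus a sound-to-image checkpoint.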
CHECKPOINTS = [
"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Chinese_clip_overnorm_over3_noln.ckpt",
"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_French_clip_overnorm_over3_noln.ckpt",
"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Italian_clip_overnorm_over3_noln.ckpt",
"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Japanese_clip_overnorm_over3_noln.ckpt",
"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Spanish_clip_overnorm_over3_noln.ckpt",
"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_sound2img_audioclip_us8k.ckpt"
]
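# Prompts keyed by the language tag used in the checkpoint filenames; only French is
# exercised here ("une voiture sur la plage" = "a car on the beach").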
LANGUAGE_PROMPTS = {
"French": "une voiture sur la plage",
}
def download_checkpoints(checkpoint_dir):
    """Download any GlueNet checkpoints that are not already present and return the directory."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    for url in CHECKPOINTS:
        filename = os.path.join(checkpoint_dir, os.path.basename(url))
        if not os.path.exists(filename):
            print(f"Downloading {filename}...")
            urllib.request.urlretrieve(url, filename)
            print(f"Downloaded {filename}")
        else:
            print(f"Checkpoint {filename} already exists, skipping download.")
    return checkpoint_dir
def load_checkpoint(pipeline, checkpoint_path, device):
    """Load a GlueNet checkpoint into the pipeline's UNet, reporting keys that do not match."""
    state_dict = torch.load(checkpoint_path, map_location=device)
    state_dict = state_dict.get("state_dict", state_dict)
    missing_keys, unexpected_keys = pipeline.unet.load_state_dict(state_dict, strict=False)
    if missing_keys or unexpected_keys:
        print(f"Warning: {len(missing_keys)} missing and {len(unexpected_keys)} unexpected keys for {checkpoint_path}")
    return pipeline
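# A hedged sketch of the adapter-based loading path that the diffusers "gluegen"
# community pipeline also exposes for these translator checkpoints. The hyperparameters
# (77 tokens, a 1024 -> 768 projection) and the tensor_norm statistics follow the
# community example and are assumptions here; "clip_tensor_norm.pt" is a hypothetical
# file holding the 77x1 CLIP norm statistics shipped with that example.
def load_language_adapter_sketch(pipeline, checkpoint_path, device):
    tensor_norm = torch.load("clip_tensor_norm.pt", map_location=device)  # hypothetical path
    pipeline.load_language_adapter(
        checkpoint_path,
        num_token=77,
        dim=1024,
        dim_out=768,
        tensor_norm=tensor_norm,
    )
    return pipeline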
def generate_image(pipeline, prompt, device, output_path):
    """Run the pipeline on a single prompt with a fixed seed and save the result."""
    with torch.inference_mode():
        image = pipeline(
            prompt,
            generator=torch.Generator(device=device).manual_seed(42),
            num_inference_steps=50
        ).images[0]
    image.save(output_path)
    print(f"Image saved to {output_path}")
checkpoint_dir = download_checkpoints("./checkpoints_all/gluenet_checkpoint")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Sanity check: confirm the multilingual XLM-RoBERTa encoder loads and runs on the device.
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
model = XLMRobertaForMaskedLM.from_pretrained("xlm-roberta-base").to(device)
inputs = tokenizer(f"Ceci est une phrase incomplète avec un {tokenizer.mask_token}.", return_tensors="pt").to(device)
with torch.inference_mode():
    _ = model(**inputs)
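# CLIP tokenizer and text encoder used by Stable Diffusion v1.5.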
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
text_encoder=clip_text_encoder,
tokenizer=clip_tokenizer,
custom_pipeline="gluegen",
safety_checker=None
).to(device)
os.makedirs("outputs", exist_ok=True)
for language, prompt in LANGUAGE_PROMPTS.items():
    checkpoint_file = f"gluenet_{language}_clip_overnorm_over3_noln.ckpt"
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)
    try:
        pipeline = load_checkpoint(pipeline, checkpoint_path, device)
        output_path = f"outputs/gluegen_output_{language.lower()}.png"
        generate_image(pipeline, prompt, device, output_path)
    except Exception as e:
        print(f"Error processing {language} model: {e}")
        continue

if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()