Transformer_Project/
├── datasets/
│   └── coco_dataset/                                # Data for training and testing
│       ├── train2017/                               # Training images
│       ├── val2017/                                 # Validation images (optional)
│       └── annotations/                             # Annotation files (captions_train2017.json, etc.)
├── models/
│   ├── text_to_image_transformer_original.pth       # Original trained model
│   └── text_to_image_transformer_optimized.pth      # Optimized model after training
├── results/
│   ├── generated_image_pretrained.png               # Image generated by the original model
│   └── generated_images_optimized/                  # Images generated by the optimized model
└── notebooks/
    ├── data_preprocessing.ipynb                     # Data loading and preprocessing
    ├── model_training.ipynb                         # Model building and training
    ├── optimization_training.ipynb                  # Optimized training with mixed precision
    ├── generate_images.ipynb                        # Generating images with trained models
    └── analysis_and_visualization.ipynb             # Analyzing and visualizing results
加载 COCO 数据集,挂载 Google Drive,并对数据进行预处理,使其适用于模型训练:
Load the COCO dataset, mount Google Drive, and preprocess data for training:
from google.colab import drive
from pycocotools.coco import COCO
from transformers import BertTokenizer, BertModel
import torch
from PIL import Image
import os
# --- Environment setup ---
# Expose Google Drive inside the Colab runtime.
drive.mount('/content/drive')

# Root folder of the COCO data on Drive and the caption annotation file in it.
dataset_path = '/content/drive/MyDrive/Transformer_Project/datasets/coco_dataset/'
annotation_file = os.path.join(dataset_path, 'annotations/captions_train2017.json')

# Build the pycocotools index over the caption annotations.
coco = COCO(annotation_file=annotation_file)

# Tokenizer for the "bert-base-uncased" checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# 定义数据加载和预处理函数 (Define data loading and preprocessing function)
def load_image_and_preprocess(img_id):
    """Load one COCO training image and tokenize its first caption.

    Relies on the module-level ``coco``, ``tokenizer`` and ``dataset_path``.

    Args:
        img_id: COCO image id (for example an element of ``coco.getImgIds()``).

    Returns:
        A ``(img, tokens)`` tuple. ``img`` is an RGB PIL image read from the
        ``train2017`` folder. ``tokens`` is the tokenizer output (PyTorch
        tensors, ``padding="max_length"``) for the image's first caption, or
        for the empty string when the image has no captions.
    """
    img_info = coco.loadImgs(img_id)[0]
    img_path = os.path.join(dataset_path, 'train2017', img_info['file_name'])

    # Image.open is lazy and keeps the file open. Converting inside the
    # ``with`` block reads the pixel data and yields an independent RGB copy,
    # so the file handle is released as soon as the block exits.
    with Image.open(img_path) as source:
        img = source.convert('RGB')

    ann_ids = coco.getAnnIds(imgIds=img_id)
    annotations = coco.loadAnns(ann_ids)

    # Only the first caption is used; images without captions get "".
    description = annotations[0]['caption'] if annotations else ""
    tokens = tokenizer(description, padding="max_length", truncation=True, return_tensors="pt")
    return img, tokens
# Smoke test: run the loader on the first image id in the annotation index.
data_example = coco.getImgIds()[0]
img, tokens = load_image_and_preprocess(img_id=data_example)
定义 Transformer 模型,用于文本到图像的生成:
Define and instantiate the Transformer model for text-to-image generation:
# 定义 Transformer 模型用于文本生成图像 (Define Transformer model for text-to-image generation)
class TextToImageTransformer(torch.nn.Module):
    """Map tokenized text to an image tensor of shape ``(batch, 3, 256, 256)``.

    The text is encoded with the pre-trained ``bert-base-uncased`` model, the
    token embeddings are mean-pooled into one 768-d vector per example, and a
    two-layer MLP ending in ``Tanh`` produces the flattened image, so output
    values lie in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        # The attribute name ``gan_decoder`` is part of the state-dict keys of
        # saved checkpoints, so it must not be renamed.
        self.gan_decoder = torch.nn.Sequential(
            torch.nn.Linear(768, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 256 * 256 * 3),
            torch.nn.Tanh(),
        )

    @staticmethod
    def _masked_mean(hidden_states, attention_mask=None):
        """Average token embeddings over the sequence axis, ignoring padding.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden)``.
            attention_mask: optional ``(batch, seq_len)`` tensor with 1 for
                real tokens and 0 for padding. When ``None`` a plain mean over
                all positions is returned.

        Returns:
            Tensor of shape ``(batch, hidden)``.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        # Guard against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, text_inputs):
        """Generate images for a batch of tokenized texts.

        Args:
            text_inputs: mapping of tokenizer outputs (``input_ids``,
                ``attention_mask``, ...) that is unpacked into the BERT
                encoder.

        Returns:
            Tensor of shape ``(batch, 3, 256, 256)``.
        """
        hidden_states = self.text_encoder(**text_inputs).last_hidden_state
        # Inputs are padded to a fixed length by the callers; without the mask
        # the mean would be dominated by padding positions.
        attention_mask = text_inputs["attention_mask"] if "attention_mask" in text_inputs else None
        text_features = self._masked_mean(hidden_states, attention_mask)
        generated_image = self.gan_decoder(text_features)
        return generated_image.view(-1, 3, 256, 256)
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model, move it to the chosen device and switch to inference mode.
model = TextToImageTransformer()
model = model.to(device)
model.eval()
使用预训练模型从给定的文本生成图像:
Use the model before any training (pre-trained BERT text encoder, freshly initialised decoder) to generate a baseline image from a given text prompt:
import numpy as np

# Text prompt to render.
input_text = "A beautiful sunset over the mountains."

# Tokenize the prompt the same way the training captions are tokenized.
tokens = tokenizer(input_text, padding="max_length", truncation=True, return_tensors="pt")

# Inference only: no gradients needed.
with torch.no_grad():
    generated_image = model(tokens.to(device))

# (C, H, W) -> (H, W, C), clamp to [0, 1] and scale to 8-bit. The model ends
# in Tanh, so negative outputs are clipped to 0 here.
image_array = (np.clip(generated_image[0].permute(1, 2, 0).cpu().numpy(), 0, 1) * 255).astype(np.uint8)
image = Image.fromarray(image_array)

# Make sure the results folder exists; Image.save does not create it.
pretrained_image_path = '/content/drive/MyDrive/Transformer_Project/results/generated_image_pretrained.png'
os.makedirs(os.path.dirname(pretrained_image_path), exist_ok=True)
image.save(pretrained_image_path)

# Display the generated image.
image.show()
通过使用自定义 DataLoader 和混合精度训练来优化模型性能:
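Train the model with mixed precision, using a custom Dataset and DataLoader. Two variants follow: the first downloads each image from its COCO URL, the second reads the images from the local dataset folder: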
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import requests
import io
import numpy as np
from torchvision import transforms
# 定义 COCO 数据集的自定义 Dataset (Define custom Dataset for COCO data)
class COCODatasetAPI(Dataset):
    """COCO caption dataset that downloads each image from its ``coco_url``.

    Each item is a ``(tokens, img)`` pair. ``tokens`` is a dict of tokenizer
    tensors for the image's first caption with the batch dimension removed.
    ``img`` is the image resized to 256x256 as a float tensor of shape
    ``(3, 256, 256)`` scaled to ``[0, 1]`` (unless ``transform`` itself
    returns a tensor, see ``_image_to_tensor``).

    Args:
        coco: loaded ``pycocotools`` ``COCO`` caption index.
        tokenizer: callable tokenizer returning PyTorch tensors.
        transform: optional callable applied to the resized PIL image.
        timeout: seconds to wait for the image server before giving up.
    """

    def __init__(self, coco, tokenizer, transform=None, timeout=30):
        self.coco = coco
        self.tokenizer = tokenizer
        self.img_ids = list(coco.imgs.keys())
        self.transform = transform
        self.timeout = timeout
        # Every image is resized to the 256x256 resolution the model outputs.
        self.resize_transform = transforms.Resize((256, 256))

    def __len__(self):
        return len(self.img_ids)

    @staticmethod
    def _first_caption(annotations):
        """Return the first caption, or "" when there are no annotations."""
        return annotations[0]['caption'] if annotations else ""

    @staticmethod
    def _image_to_tensor(img):
        """Convert an HWC uint8 image to a CHW float tensor in ``[0, 1]``.

        A value that is already a tensor (for example the output of a
        tensor-producing ``transform``) is returned unchanged, because
        permuting and rescaling it again would corrupt it.
        """
        if isinstance(img, torch.Tensor):
            return img
        return torch.tensor(np.array(img)).permute(2, 0, 1) / 255.0

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]

        # Download the image; fail loudly on a hung connection or HTTP error
        # instead of handing an error page to the image decoder.
        img_info = self.coco.loadImgs(img_id)[0]
        response = requests.get(img_info['coco_url'], timeout=self.timeout)
        response.raise_for_status()

        # convert() reads the pixel data and returns an independent RGB copy.
        with Image.open(io.BytesIO(response.content)) as source:
            img = source.convert('RGB')

        img = self.resize_transform(img)
        if self.transform:
            img = self.transform(img)

        # Tokenize the first caption of the image.
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        annotations = self.coco.loadAnns(ann_ids)
        description = self._first_caption(annotations)
        tokens = self.tokenizer(description, padding="max_length", truncation=True, return_tensors="pt")

        # Drop only the leading batch dimension; the DataLoader adds its own.
        tokens = {k: v.squeeze(0) for k, v in tokens.items()}
        return tokens, self._image_to_tensor(img)
# Dataset and DataLoader over images fetched from their COCO URLs.
dataset_api = COCODatasetAPI(coco, tokenizer)
dataloader_api = DataLoader(dataset_api, batch_size=4, shuffle=True)

# Pixel-wise regression of the generated image onto the target image.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

# torch.cuda.amp only applies on CUDA; disable it explicitly on a CPU runtime
# so autocast and GradScaler become plain pass-throughs.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

# The model was switched to eval mode right after instantiation; put it back
# into training mode so layers that behave differently in training (such as
# dropout) are active.
model.train()
for epoch in range(10):
    for tokens, target_images in dataloader_api:
        tokens = {k: v.to(device) for k, v in tokens.items()}
        target_images = target_images.to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            output_images = model(tokens)
            loss = loss_fn(output_images, target_images)

        # Scale the loss to avoid fp16 gradient underflow, then step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

# Save the optimized model, creating the target folder if it is missing.
optimized_model_path = '/content/drive/MyDrive/Transformer_Project/models/text_to_image_transformer_optimized.pth'
os.makedirs(os.path.dirname(optimized_model_path), exist_ok=True)
torch.save(model.state_dict(), optimized_model_path)
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import os
import numpy as np
from torchvision import transforms
# 定义 COCO 数据集的自定义 Dataset (Define custom Dataset for COCO data)
class COCODatasetLocal(Dataset):
    """COCO caption dataset that reads images from a local folder.

    Each item is a ``(tokens, img)`` pair. ``tokens`` is a dict of tokenizer
    tensors for the image's first caption with the batch dimension removed.
    ``img`` is the image resized to 256x256 as a float tensor of shape
    ``(3, 256, 256)`` scaled to ``[0, 1]`` (unless ``transform`` itself
    returns a tensor, see ``_image_to_tensor``).

    Args:
        coco: loaded ``pycocotools`` ``COCO`` caption index.
        tokenizer: callable tokenizer returning PyTorch tensors.
        transform: optional callable applied to the resized PIL image.
        image_dir: folder containing the image files. Defaults to the
            ``train2017`` folder under the module-level ``dataset_path``.
    """

    def __init__(self, coco, tokenizer, transform=None, image_dir=None):
        self.coco = coco
        self.tokenizer = tokenizer
        self.img_ids = list(coco.imgs.keys())
        self.transform = transform
        # Fall back to the notebook-wide dataset location when no folder is given.
        if image_dir is None:
            image_dir = os.path.join(dataset_path, 'train2017')
        self.image_dir = image_dir
        # Every image is resized to the 256x256 resolution the model outputs.
        self.resize_transform = transforms.Resize((256, 256))

    def __len__(self):
        return len(self.img_ids)

    @staticmethod
    def _first_caption(annotations):
        """Return the first caption, or "" when there are no annotations."""
        return annotations[0]['caption'] if annotations else ""

    @staticmethod
    def _image_to_tensor(img):
        """Convert an HWC uint8 image to a CHW float tensor in ``[0, 1]``.

        A value that is already a tensor (for example the output of a
        tensor-producing ``transform``) is returned unchanged, because
        permuting and rescaling it again would corrupt it.
        """
        if isinstance(img, torch.Tensor):
            return img
        return torch.tensor(np.array(img)).permute(2, 0, 1) / 255.0

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]

        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.image_dir, img_info['file_name'])

        # Image.open is lazy; convert() inside the ``with`` block reads the
        # pixels and returns an independent RGB copy, so the file handle is
        # closed before the next item is loaded.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        img = self.resize_transform(img)
        if self.transform:
            img = self.transform(img)

        # Tokenize the first caption of the image.
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        annotations = self.coco.loadAnns(ann_ids)
        description = self._first_caption(annotations)
        tokens = self.tokenizer(description, padding="max_length", truncation=True, return_tensors="pt")

        # Drop only the leading batch dimension; the DataLoader adds its own.
        tokens = {k: v.squeeze(0) for k, v in tokens.items()}
        return tokens, self._image_to_tensor(img)
# Dataset and DataLoader over images read from the local dataset folder.
dataset_local = COCODatasetLocal(coco, tokenizer)
dataloader_local = DataLoader(dataset_local, batch_size=4, shuffle=True)

# Pixel-wise regression of the generated image onto the target image.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

# torch.cuda.amp only applies on CUDA; disable it explicitly on a CPU runtime
# so autocast and GradScaler become plain pass-throughs.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

# The model was switched to eval mode right after instantiation; put it back
# into training mode so layers that behave differently in training (such as
# dropout) are active.
model.train()
for epoch in range(10):
    for tokens, target_images in dataloader_local:
        tokens = {k: v.to(device) for k, v in tokens.items()}
        target_images = target_images.to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            output_images = model(tokens)
            loss = loss_fn(output_images, target_images)

        # Scale the loss to avoid fp16 gradient underflow, then step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

# Save the optimized model, creating the target folder if it is missing.
optimized_model_path = '/content/drive/MyDrive/Transformer_Project/models/text_to_image_transformer_optimized.pth'
os.makedirs(os.path.dirname(optimized_model_path), exist_ok=True)
torch.save(model.state_dict(), optimized_model_path)
使用优化后的模型从给定的文本生成图像,并保存结果以便进行比较:
Use the optimized model to generate an image from a given text prompt, and save the result for comparison:
# Load the optimized weights. map_location remaps the stored tensors onto the
# current device, so a checkpoint saved on a GPU also loads on a CPU runtime.
optimized_weights_path = '/content/drive/MyDrive/Transformer_Project/models/text_to_image_transformer_optimized.pth'
model.load_state_dict(torch.load(optimized_weights_path, map_location=device))
model.eval()

# Generate an image for the same prompt used for the baseline image.
input_text = "A beautiful sunset over the mountains."
tokens = tokenizer(input_text, padding="max_length", truncation=True, return_tensors="pt")
with torch.no_grad():
    generated_image_optimized = model(tokens.to(device))

# (C, H, W) -> (H, W, C), clamp to [0, 1] and scale to 8-bit.
image_array_optimized = (np.clip(generated_image_optimized[0].permute(1, 2, 0).cpu().numpy(), 0, 1) * 255).astype(np.uint8)
image_optimized = Image.fromarray(image_array_optimized)

# Make sure the output folder exists; Image.save does not create it.
optimized_image_path = '/content/drive/MyDrive/Transformer_Project/results/generated_images_optimized/generated_image_optimized.png'
os.makedirs(os.path.dirname(optimized_image_path), exist_ok=True)
image_optimized.save(optimized_image_path)
将训练前后的生成结果进行可视化,以比较模型的改进情况:
Visualize the generated results before and after training to compare the improvements:
import matplotlib.pyplot as plt
# Read back the two saved renders.
img_pretrained = Image.open('/content/drive/MyDrive/Transformer_Project/results/generated_image_pretrained.png')
img_optimized = Image.open('/content/drive/MyDrive/Transformer_Project/results/generated_images_optimized/generated_image_optimized.png')

# One row, two panels: before training on the left, after training on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
for _axis, _picture, _title in zip(
    ax,
    (img_pretrained, img_optimized),
    ("Pre-trained Model", "Optimized Model"),
):
    _axis.imshow(_picture)
    _axis.set_title(_title)
    _axis.axis('off')

plt.suptitle('Comparison of Generated Images')
plt.show()