Github Link
https://github.com/DiT-3D/DiT-3D/blob/main/train.py
*๋งค์ฐ๋งค์ฐ ๊ธ์ด ๊ธด ์ด์ฅ๋ฌธ์ ๋๋ค..!
๊ฐ ์ฝ๋๋ณ๋ก ์์ฒญ ์์ธํ๊ฒ ๋ฆฌ๋ทฐํ๊ณ , ์ต๋ํ ํ๋ฆ์ ๋ฐ๋ผ์ ์ฝ๋์ ์์์ ๋ถ์ฌ์ ์ค๋ช ํ์์ต๋๋ค.
*DiT-3D ์ฝ๋๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ค๋ช ํ๊ณ ์์ง๋ง, dataloader๋ฅผ ์ ์ธํ ๋๋จธ์ง ๋ฆฌ๋ทฐ๋ 2D ๊ธฐ๋ฐ์ DDPM or Diffusion Transformer ์ฝ๋์ ํ๋ฆ์ผ๋ก ์ดํดํ์ ๋ ๋ฌด๊ดํฉ๋๋ค.
Contents
📦DataLoader Code
# train.py line 565
# Build only the training split here; get_dataset also returns the validation
# dataset (split='val'), which is discarded via `_`.
train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
## get_dataset function
def get_dataset(dataroot, npoints, category):
    """Build the ShapeNet (PC15k) training and validation point-cloud datasets.

    Args:
        dataroot: Root directory holding the per-synset point-cloud folders.
        npoints: Number of points drawn per shape for both the train and the
            test point subsets (2048 in the walkthrough).
        category: Comma-separated category names, e.g. ``"chair,car"``.

    Returns:
        ``(train_dataset, val_dataset)``. Only the training set subsamples
        points randomly; the validation set reuses the training set's
        normalization statistics so both live in the same coordinate frame.
    """
    common_kwargs = dict(
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
    )
    train_split = ShapeNet15kPointClouds(root_dir=dataroot,
                                         categories=category.split(','),
                                         split='train',
                                         random_subsample=True,
                                         **common_kwargs)
    val_split = ShapeNet15kPointClouds(root_dir=dataroot,
                                       categories=category.split(','),
                                       split='val',
                                       all_points_mean=train_split.all_points_mean,
                                       all_points_std=train_split.all_points_std,
                                       **common_kwargs)
    return train_split, val_split
# tr_sample_size = 2048
# te_sample_size = 2048
class ShapeNet15kPointClouds(Uniform15KPC):
    """ShapeNetCore.v2 PC15k point clouds selected by human-readable category names.

    Category names are translated to synset ids through ``cate_to_synsetid``;
    a category list containing ``'all'`` selects every synset in that mapping.
    Loading, shuffling, normalization and sampling are delegated to
    :class:`Uniform15KPC`.
    """

    def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
                 categories=None, tr_sample_size=10000, te_sample_size=2048,
                 split='train', scale=1., normalize_per_shape=False,
                 normalize_std_per_axis=False, box_per_shape=False,
                 random_subsample=False,
                 all_points_mean=None, all_points_std=None,
                 use_mask=False):
        if categories is None:
            # Fresh list per call: a list literal as the default would be one
            # object shared by (and stored as self.cates on) every instance
            # constructed with the default.
            categories = ['airplane']
        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]

        # assert 'v2' in root_dir, "Only supporting v2 right now."
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        super().__init__(
            root_dir, self.synset_ids,
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split, scale=scale,
            normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean, all_points_std=all_points_std,
            input_dim=3, use_mask=use_mask)
- https://github.com/DiT-3D/DiT-3D/blob/main/datasets/shapenet_data_pc.py#L201
- → ShapeNet15kPointClouds class๋ Uniform15KPC class๋ฅผ ์์๋ฐ์.
class Uniform15KPC(Dataset):
    """Point clouds stored as ``<root_dir>/<subdir>/<split>/<mid>.npy``.

    Every file must contain a ``(15000, input_dim)`` array. All shapes are loaded
    into memory, shuffled with a fixed seed, normalized, and each shape's points
    are split by index: the first 10000 form ``train_points`` and the remaining
    5000 form ``test_points``.

    Normalization statistics (first matching rule wins):
      * ``all_points_mean`` and ``all_points_std`` both given: used as-is (e.g.
        the training set's statistics for a validation set);
      * ``normalize_per_shape``: per-shape mean, and per-shape std (scalar, or
        per axis when ``normalize_std_per_axis``);
      * ``box_per_shape``: per-shape min as offset and (max - min) as scale;
      * otherwise: dataset-wide per-axis mean, and a dataset-wide std (scalar,
        or per axis when ``normalize_std_per_axis``).
    Whenever ``box_per_shape`` is set, normalized points are additionally
    shifted by -0.5.

    Raises:
        ValueError: if no readable ``.npy`` file is found.
    """

    def __init__(self, root_dir, subdirs, tr_sample_size=10000,
                 te_sample_size=10000, split='train', scale=1.,
                 normalize_per_shape=False, box_per_shape=False,
                 random_subsample=False,
                 normalize_std_per_axis=False,
                 all_points_mean=None, all_points_std=None,
                 input_dim=3, use_mask=False):
        self.root_dir = root_dir
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.subdirs = subdirs
        self.scale = scale
        self.random_subsample = random_subsample
        self.input_dim = input_dim
        self.use_mask = use_mask
        self.box_per_shape = box_per_shape
        if use_mask:
            self.mask_transform = PointCloudMasks(radius=5, elev=5, azim=90)

        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s" % sub_path)
                continue

            # os.listdir order is filesystem-dependent; sorting it makes the
            # seeded shuffle below reproducible across machines.
            # NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
            all_mids = [os.path.join(self.split, x[:-len('.npy')])
                        for x in sorted(os.listdir(sub_path)) if x.endswith('.npy')]

            for mid in all_mids:
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # (15k, 3)
                except (OSError, ValueError, EOFError) as err:
                    # Best effort: skip files numpy cannot read, but report them
                    # instead of silently swallowing every exception.
                    print("Skipping unreadable point cloud %s: %s" % (obj_fname, err))
                    continue
                assert point_cloud.shape[0] == 15000
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

        if not self.all_points:
            # np.concatenate below would otherwise fail with an opaque message.
            raise ValueError("No point clouds found under %s for subdirs %s (split=%s)"
                             % (root_dir, list(self.subdirs), self.split))

        # Shuffle the index deterministically (fixed seed)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]

        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        elif self.normalize_per_shape:  # per shape normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
        elif self.box_per_shape:  # per-shape bounding box: offset = min, scale = max - min
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)
            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)
        else:  # normalize across the dataset: per-axis mean, scalar (or per-axis) std
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        if self.box_per_shape:
            # Center box-normalized coordinates (in [0, 1] when the box stats were
            # computed from this data) around zero.
            self.all_points = self.all_points - 0.5
        # Split every shape's points by index: first 10000 train, last 5000 test.
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        print("Total number of data:%d" % len(self.train_points))
        print("Min number of points: (train)%d (test)%d"
              % (self.tr_sample_size, self.te_sample_size))
        assert self.scale == 1, "Scale (!= 1) is deprecated"

    def get_pc_stats(self, idx):
        """Return the ``(mean, std)`` used to normalize shape ``idx``.

        Per-shape statistics (``normalize_per_shape`` / ``box_per_shape``) are
        indexed by ``idx``; otherwise the dataset-wide statistics are returned.
        Both are reshaped to 2-D arrays with a leading dimension of 1.
        """
        if self.normalize_per_shape or self.box_per_shape:
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s

        return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
        """Re-express the stored points using new normalization statistics.

        Undoes the current normalization (including the -0.5 offset applied when
        ``box_per_shape`` is set), normalizes with ``mean``/``std``, and re-applies
        that offset so the dataset keeps the same convention as ``__init__``.
        """
        raw = self.all_points
        if self.box_per_shape:
            raw = raw + 0.5
        raw = raw * self.all_points_std + self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (raw - self.all_points_mean) / self.all_points_std
        if self.box_per_shape:
            self.all_points = self.all_points - 0.5
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

    def __len__(self):
        return len(self.train_points)

    def __getitem__(self, idx):
        """Return shape ``idx`` as a dict of point subsets and metadata.

        With ``random_subsample`` the train/test points are drawn with
        ``np.random.choice`` (with replacement, its default); otherwise the first
        ``tr_sample_size`` / ``te_sample_size`` points are used.
        """
        tr_out = self.train_points[idx]
        if self.random_subsample:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()

        te_out = self.test_points[idx]
        if self.random_subsample:
            te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
        else:
            te_idxs = np.arange(self.te_sample_size)
        te_out = torch.from_numpy(te_out[te_idxs, :]).float()

        m, s = self.get_pc_stats(idx)
        cate_idx = self.cate_idx_lst[idx]
        sid, mid = self.all_cate_mids[idx]

        out = {
            'idx': idx,
            'train_points': tr_out,
            'test_points': te_out,
            'mean': m, 'std': s, 'cate_idx': cate_idx,
            'sid': sid, 'mid': mid
        }

        if self.use_mask:
            tr_mask = self.mask_transform(tr_out)
            out['train_masks'] = tr_mask

        return out
- PointFlow ๋
ผ๋ฌธ์ ๊ทผ๊ฑฐํ์ฌ pre-processing์ ์งํํ๊ณ ์์.
- self.all_points ⇒ (N, 15000, 3)
- (1) Normalization: Points axis๋ฅผ ๊ธฐ์ค์ผ๋ก, mean๊ณผ std๋ฅผ ๊ณ์ฐํ์ฌ ์๋์ ๊ฐ์ด ์ ์ฉ
# Excerpt from Uniform15KPC.__init__: standardize every point with the
# precomputed mean/std, broadcast over the (N, 15000, 3) array.
self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
- (2) Training๊ณผ Testingํ ๋ 2048๊ฐ์ points๋ฅผ randomํ๊ฒ samplingํจ.
def __getitem__(self, idx):
    """Excerpt of Uniform15KPC.__getitem__ (the use_mask branch is omitted).

    Returns a dict with one shape's train/test point subsets (float tensors),
    its normalization statistics, and its category / model identifiers.
    """
    tr_out = self.train_points[idx]
    if self.random_subsample:  # True for the training set built by get_dataset
        # np.random.choice samples indices with replacement by default.
        tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
    else:
        tr_idxs = np.arange(self.tr_sample_size)
    tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()

    te_out = self.test_points[idx]
    if self.random_subsample:
        te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
    else:
        te_idxs = np.arange(self.te_sample_size)
    te_out = torch.from_numpy(te_out[te_idxs, :]).float()

    m, s = self.get_pc_stats(idx)  # m: mean, s: std
    # These lookups were dropped from the excerpt, which left cate_idx / sid /
    # mid undefined when building `out`.
    cate_idx = self.cate_idx_lst[idx]
    sid, mid = self.all_cate_mids[idx]
    out = {
        'idx': idx,
        'train_points': tr_out,
        'test_points': te_out,
        'mean': m, 'std': s, 'cate_idx': cate_idx,
        'sid': sid, 'mid': mid
    }
    return out
์๋ฌธ์ 1๊ฐ์ง
*๊ทผ๋ฐ train-test๋ฅผ ๊ตฌ๋ถํ ๋, point๋ฅผ ๊ธฐ์ค์ผ๋ก ํ๋?
- ์ฒ์ ๋ณธ point cloud ์ขํ๋ผ์ ๊ด์ฐฎ์๊ฑด๊ฐ?
# Excerpt from Uniform15KPC.__init__: each shape's 15000 points are split by
# index -- the first 10000 become training points, the remaining 5000 test points.
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]
Model Training
# <https://github.com/DiT-3D/DiT-3D/blob/cfcc16b62004d735e935d14e0e8dc2d00a96041d/train.py#L410>
class Model(nn.Module):
    """Bundle a DiT-3D denoiser with a GaussianDiffusion process.

    ``self.model`` is the noise-prediction network (the window-attention
    variant when ``args.window_size > 0``); ``self.diffusion`` implements the
    forward (q) and reverse (p) processes and calls the network through
    :meth:`_denoise`.
    """

    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        # Use the window-attention registry when a window size is configured.
        if args.window_size > 0:
            self.model = DiT3D_models_WindAttn[args.model_type](pretrained=args.use_pretrained,
                                                                input_size=args.voxel_size,
                                                                window_size=args.window_size,
                                                                window_block_indexes=args.window_block_indexes,
                                                                num_classes=args.num_classes
                                                                )
        else:
            self.model = DiT3D_models[args.model_type](pretrained=args.use_pretrained,
                                                       input_size=args.voxel_size,
                                                       num_classes=args.num_classes
                                                       )

    def prior_kl(self, x0):
        """Prior term in bits per dim (delegates to ``diffusion._prior_bpd``)."""
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, y, clip_denoised=True):
        """Full variational bound from ``diffusion.calc_bpd_loop``, returned as a dict."""
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(self._denoise, x0, y, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt': mse_bt
        }

    def _denoise(self, data, t, y):
        """Run the denoiser and check that its output has the input shape [B, D, N]."""
        B, D, N = data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t, y)

        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None, y=None):
        """Evaluate the training loss once: a random timestep per example, then ``p_losses``.

        If ``noises`` is given it is refreshed in place with new Gaussian noise for
        every example whose sampled t is non-zero. Returns losses of shape [B].
        """
        B, D, N = data.shape  # e.g. [16, 3, 2048]
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            noises[t != 0] = torch.randn((t != 0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises, y=y)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, shape, device, y, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        """Sample from the reverse diffusion chain (``diffusion.p_sample_loop``)."""
        return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, y=y, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def gen_sample_traj(self, shape, device, y, freq, noise_fn=torch.randn,
                        clip_denoised=True, keep_running=False):
        """Sample while recording the trajectory (``diffusion.p_sample_loop_trajectory``)."""
        return self.diffusion.p_sample_loop_trajectory(self._denoise, shape=shape, device=device, y=y, noise_fn=noise_fn, freq=freq,
                                                       clip_denoised=clip_denoised,
                                                       keep_running=keep_running)

    def train(self, mode: bool = True):
        """Put the wrapped denoiser in training (``mode=True``) or eval mode.

        Accepts ``mode`` like :meth:`torch.nn.Module.train`, so
        ``model.train(False)`` works, and returns ``self`` for chaining.
        """
        self.training = mode
        self.model.train(mode)
        return self

    def eval(self):
        """Put the wrapped denoiser in eval mode; returns ``self``."""
        return self.train(False)

    def multi_gpu_wrapper(self, f):
        """Replace ``self.model`` with ``f(self.model)`` (e.g. a data-parallel wrapper)."""
        self.model = f(self.model)
- get_loss_iter ํจ์๋ฅผ ํตํด์ ํ๋ จ
- self.diffusion → GaussianDiffusion class
# p_losses (GaussianDiffusion method): the denoising training objective.
def p_losses(self, denoise_fn, data_start, t, noise=None, y=None):
    """Compute the per-example training loss at timesteps ``t``.

    ``data_start`` is diffused to ``data_t`` with ``q_sample``; then, depending
    on ``self.loss_type``:
      * ``'mse'``: mean squared error between the true noise and the noise
        predicted by ``denoise_fn``, averaged over all non-batch dimensions;
      * ``'kl'``: the variational-bound terms from ``_vb_terms_bpd``.

    Returns a tensor of shape [B]; raises NotImplementedError for any other
    loss type.
    """
    batch, dims, n_points = data_start.shape
    assert t.shape == torch.Size([batch])

    if noise is None:
        noise = torch.randn(data_start.shape, dtype=data_start.dtype, device=data_start.device)
    assert noise.shape == data_start.shape and noise.dtype == data_start.dtype

    # Forward diffusion: noisy sample at timestep t.
    data_t = self.q_sample(x_start=data_start, t=t, noise=noise)

    kind = self.loss_type
    if kind == 'kl':
        losses = self._vb_terms_bpd(
            denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, y=y, clip_denoised=False,
            return_pred_xstart=False)
    elif kind == 'mse':
        # The network predicts the noise rather than x_start.
        eps_recon = denoise_fn(data_t, t, y)
        assert data_t.shape == data_start.shape
        assert eps_recon.shape == torch.Size([batch, dims, n_points])
        assert eps_recon.shape == data_start.shape
        reduce_dims = list(range(1, len(data_start.shape)))
        losses = (noise - eps_recon).pow(2).mean(dim=reduce_dims)
    else:
        raise NotImplementedError(kind)

    assert losses.shape == torch.Size([batch])
    return losses
# q_sample: returns x_t, i.e. the input x_0 with the noise for timestep t mixed in.
def q_sample(self, x_start, t, noise=None):
    """Diffuse the data in closed form (t == 0 means diffused for 1 step).

        x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise

    ``self.sqrt_alphas_cumprod`` / ``self.sqrt_one_minus_alphas_cumprod`` are
    precomputed when the diffusion object is initialized (alphas = 1 - betas).
    """
    if noise is None:
        noise = torch.randn(x_start.shape, device=x_start.device)
    assert noise.shape == x_start.shape

    device = x_start.device
    signal_coef = self._extract(self.sqrt_alphas_cumprod.to(device), t, x_start.shape)
    noise_coef = self._extract(self.sqrt_one_minus_alphas_cumprod.to(device), t, x_start.shape)
    return signal_coef * x_start + noise_coef * noise
def _extract(a, t, x_shape):
    """Gather ``a[t]`` for every batch element and reshape it for broadcasting.

    Returns a tensor of shape ``[batch_size, 1, 1, ...]`` with one trailing
    singleton dimension per non-batch dimension of ``x_shape``.
    """
    (batch_size,) = t.shape
    assert x_shape[0] == batch_size
    picked = torch.gather(a, 0, t)
    assert picked.shape == torch.Size([batch_size])
    broadcast_shape = [batch_size] + [1] * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape)
- p_losses → q_sample → p_losses (mse ์ฐ์ฐ)
- q_sample๋ก ๋ถํฐ, x input ๊ฐ์ noise๋ก ์ ํ.
- denoise_fn(data_t, t, y)๋ก ๋ค์ด๊ฐ
- Model class์ _denoise function → DiT-3D ๋ธ๋ญ forward๋ก ๋ค์ด๊ฐ
# Registry used by Model.__init__ when args.window_size > 0: maps model names
# ("DiT-<size>/<suffix>") to constructor functions.
# The suffix appears to be the patch size (DiT_S_2 below uses patch_size=2).
# NOTE(review): in this excerpt only DiT_S_2 is shown, and it is defined after
# this dict; as a standalone module the constructors must be defined first.
DiT3D_models_WindAttn = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
    'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
    'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
}
def DiT_S_2(pretrained=False, **kwargs):
    """Build the DiT-S/2 configuration: 12 blocks, hidden size 384, 6 heads, patch size 2.

    ``pretrained`` is accepted for interface compatibility but not used here;
    all other keyword arguments are forwarded to ``DiT``.
    """
    small_patch2 = dict(depth=12, hidden_size=384, patch_size=2, num_heads=6)
    return DiT(**small_patch2, **kwargs)
- https://github.com/DiT-3D/DiT-3D/blob/main/models/dit3d_window_attn.py
- DiT Class๋ฅผ return ๋ฐ์
def forward(self, x, t, y):
    """
    Forward pass of DiT.
    x: (N, C, P) tensor of spatial inputs (point clouds or latent representations of images)
    t: (N,) tensor of diffusion timesteps
    y: (N,) tensor of class labels
    The result is devoxelized back to per-point values (Model._denoise checks
    that it has the same shape as x).
    """
    # Voxelization: the raw points serve as both features and coordinates.
    voxels, voxel_coords = self.voxelization(x, x)

    # Embed the voxel grid into tokens and add positional embeddings.
    tokens = self.x_embedder(voxels) + self.pos_embed

    # Conditioning vector = timestep embedding + class embedding.
    t_emb = self.t_embedder(t)
    y_emb = self.y_embedder(y, self.training)
    cond = t_emb + y_emb

    for block in self.blocks:
        tokens = block(tokens, cond)
    tokens = self.final_layer(tokens, cond)
    voxel_out = self.unpatchify_voxels(tokens)

    # Devoxelization via F.trilinear_devoxelize (defined outside this excerpt).
    return F.trilinear_devoxelize(voxel_out, voxel_coords, self.input_size, self.training)
- self.voxelization → voxel 형태로 변환
- Nx3 → VxVxVx3 (default V: 32)
- self.x_embedder: Conv3D block
- (B, C, X, Y, Z) → (B, E, X, Y, Z)
- C는 coordinates 3을 나타내는 값
- self.t_embedder: t embedding
- t값은 batch size 만큼의 정수 형태임 (shape: (B,))
# Excerpt from the timestep embedder (used as self.t_embedder in DiT.forward).
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: append one zero column so the output width is exactly `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def forward(self, t):
    """Embed timesteps ``t``: sinusoidal features followed by ``self.mlp``."""
    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
    t_emb = self.mlp(t_freq)
    return t_emb
- ์์ ์ฝ๋๋ฅผ ํตํด์, sincos ํจ์ ์ ์ฉ + 2๊ฐ์ MLP & GeLU layer ํต๊ณผ
- self.y_embedder: y(class) embedding
- self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
- block์ DiTBlock class๋ก ๋ค๊ฐ
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Input tokens x are (B, X*Y*Z, hidden_size) for the voxel grid given by
    ``input_size``; they are reshaped to (B, X, Y, Z, C) around the attention
    step. ``window_size > 0`` enables the window_partition / window_unpartition
    path. ``use_residual_block`` and ``block_kwargs`` are accepted but unused here.
    """
    def __init__(self, hidden_size,
                 num_heads,
                 mlp_ratio=4.0,
                 use_rel_pos=False,
                 rel_pos_zero_init=True,
                 window_size=0,
                 use_residual_block=False,
                 input_size=None, **block_kwargs):
        super().__init__()
        # No learnable affine params: per-sample shift/scale come from adaLN_modulation.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            # With windows enabled, the attention module is sized for one
            # window_size^3 cube instead of the full voxel grid.
            input_size=input_size if window_size == 0 else (window_size, window_size, window_size),
        )
        self.input_size = input_size
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # approx_gelu = lambda: nn.GELU(approximate="tanh")
        # Plain (exact) GELU is used instead of the tanh approximation above,
        # per the original note, for torch 1.7.1 compatibility.
        approx_gelu = lambda: nn.GELU() # for torch 1.7.1
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # Maps the conditioning vector to 6 * hidden_size values, split in forward().
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.window_size = window_size

    def forward(self, x, c):
        # c is the conditioning vector (in DiT.forward: timestep embedding + class embedding).
        # adaLN: SiLU + Linear on c yields shift/scale/gate for the attention (msa)
        # sub-block and for the MLP sub-block.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)

        shortcut = x
        x = self.norm1(x)
        B = x.shape[0]
        # Token sequence -> voxel grid (B, X, Y, Z, C).
        x = x.reshape(B, self.input_size[0], self.input_size[1], self.input_size[2], -1)

        # Window partition
        if self.window_size > 0:
            X, Y, Z = x.shape[1], x.shape[2], x.shape[3]
            x, pad_xyz = window_partition(x, self.window_size)

        # NOTE(review): the partitioned windows are reshaped straight back to
        # (B, X*Y*Z, C) and then to (B, X, Y, Z, C) before self.attn, so whether
        # attention is actually restricted to windows depends on window_partition
        # and Attention (not shown). If window_partition pads the grid (pad_xyz),
        # these fixed-size reshapes would not match the padded token count.
        # Verify against models/dit3d_window_attn.py.
        x = x.reshape(B, self.input_size[0] * self.input_size[1] * self.input_size[2], -1)
        x = modulate(x, shift_msa, scale_msa)
        x = x.reshape(B, self.input_size[0], self.input_size[1], self.input_size[2], -1)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_xyz, (X, Y, Z))
        x = x.reshape(B, self.input_size[0] * self.input_size[1] * self.input_size[2], -1)

        # Gated residual connections: attention branch, then MLP branch.
        x = shortcut + gate_msa.unsqueeze(1) * x
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
- self.adaLN_modulation(c).chunk(6, dim=1)
- ํตํด์, scale๊ณผ shift factors๋ฅผ ์์ฑ
- DiTBlock์ด ๋ค ๋๋๋ฉด, DiT class์์ final layer๋ก ๋ค์ด๊ฐ
- out_channels: 3์ด๋ผ์, coordinates ํํ๋ก ๋์ด
-
class FinalLayer(nn.Module):
    """
    The final layer of DiT: adaLN-modulated LayerNorm followed by a linear
    projection of every token to patch_size**3 * out_channels values.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        # Affine-free norm; shift/scale are predicted from the conditioning c.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # One output per (voxel in a p x p x p patch) x channel.
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x  # (p, p, p, c) values per token
- Loss ๊ณ์ฐ์, p_losses ํจ์์์
- losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
- ์์ ๊ฐ์ด ๊ณ์ฐ๋จ. (noise ์ ๋๋ฅผ ์์ธก; MSE)
Model Inference
# Inference call: sample a batch with the same shape as `x`, conditioned on class
# labels produced by new_y_chain (not shown). clip_denoised=False means the
# predicted x_0 is not clamped in p_mean_variance. The result is detached and
# moved to the CPU.
gen = model.gen_samples(x.shape, gpu, new_y_chain(gpu,y.shape[0],opt.num_classes), clip_denoised=False).detach().cpu()
→ Model class์์ gen_samples ํจ์๋ฅผ ํตํด์ ๋ค์ด๊ฐ
# gen_samples (Model method): entry point for inference.
def gen_samples(self, shape, device, y, noise_fn=torch.randn,
                clip_denoised=True,  # the inference call above passes False
                keep_running=False):
    """Generate samples of ``shape`` by running the reverse diffusion chain.

    Thin wrapper: forwards every argument to ``self.diffusion.p_sample_loop``
    (GaussianDiffusion) with ``self._denoise`` as the denoising network.
    """
    sample_loop = self.diffusion.p_sample_loop
    return sample_loop(self._denoise, shape=shape, device=device, y=y, noise_fn=noise_fn,
                       clip_denoised=clip_denoised, keep_running=keep_running)
# p_sample_loop (GaussianDiffusion method)
def p_sample_loop(self, denoise_fn, shape, device, y,
                  noise_fn=torch.randn, clip_denoised=True, keep_running=False):
    """
    Generate samples by running the reverse chain from pure noise.
    keep_running: iterate over len(self.betas) steps instead of
    self.num_timesteps (original note: "True if we run 2 x num_timesteps,
    False if we just run num_timesteps").
    """
    assert isinstance(shape, (tuple, list))
    # x_T: random noise point cloud drawn with noise_fn (torch.randn by default).
    img_t = noise_fn(size=shape, dtype=torch.float, device=device)
    n_steps = len(self.betas) if keep_running else self.num_timesteps
    # Walk the chain backwards: t = n_steps - 1, ..., 0.
    for step in range(n_steps - 1, -1, -1):
        t_batch = torch.full((shape[0],), step, dtype=torch.int64, device=device)
        img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t, t=t_batch, noise_fn=noise_fn, y=y,
                              clip_denoised=clip_denoised, return_pred_xstart=False)

    assert img_t.shape == shape
    return img_t
- Samplingํ ๋, timesteps = 1000์ผ๋ก ์ค์ ํจ.
- p_sample_loop ํจ์๋, p_sample์ timesteps ๋งํผ ๋ฐ๋ณต ์ํํ ๊ฒ.
def p_sample(self, denoise_fn, data, t, noise_fn, y, clip_denoised=False, return_pred_xstart=False):
    """Take one reverse-diffusion step x_t -> x_{t-1} (DDPM Algorithm 2).

    sample = model_mean + exp(0.5 * model_log_variance) * z, where the noise
    term is dropped for batch elements with t == 0. Returns ``sample``, or
    ``(sample, pred_xstart)`` when ``return_pred_xstart`` is True.
    """
    mean, _, log_var, pred_xstart = self.p_mean_variance(
        denoise_fn, data=data, t=t, y=y, clip_denoised=clip_denoised, return_pred_xstart=True)
    noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
    assert noise.shape == data.shape

    # 1 for batch elements with t > 0, 0 where t == 0 (the final step adds no
    # noise); shaped [B, 1, 1, ...] to broadcast over the remaining dims.
    mask_shape = [data.shape[0]] + [1] * (len(data.shape) - 1)
    nonzero_mask = (1 - (t == 0).float()).reshape(mask_shape)

    sample = mean + nonzero_mask * torch.exp(0.5 * log_var) * noise
    assert sample.shape == pred_xstart.shape
    if return_pred_xstart:
        return sample, pred_xstart
    return sample
# p_mean_variance (GaussianDiffusion method)
def p_mean_variance(self, denoise_fn, data, t, y, clip_denoised: bool, return_pred_xstart: bool):
    """Mean and variance of the reverse step p(x_{t-1} | x_t).

    The network output is read according to ``self.model_mean_type`` (only
    ``'eps'`` is supported). It is converted to an x_0 prediction, optionally
    clamped to [-0.5, 0.5], and plugged into the closed-form posterior mean
    of q(x_{t-1} | x_t, x_0). The variance is fixed by ``self.model_var_type``:
      * ``'fixedsmall'``: posterior variance and its clipped log;
      * ``'fixedlarge'``: betas, with posterior_variance[1] replacing betas[0]
        in the log-variance.

    Returns ``(model_mean, model_variance, model_log_variance)``, plus the x_0
    prediction when ``return_pred_xstart`` is True; all have ``data``'s shape.

    Raises:
        NotImplementedError: unsupported ``model_var_type`` or ``model_mean_type``.
    """
    model_output = denoise_fn(data, t, y)

    # Only the selected variance schedule is built.
    if self.model_var_type == 'fixedsmall':
        model_variance = self.posterior_variance.to(data.device)
        model_log_variance = self.posterior_log_variance_clipped.to(data.device)
    elif self.model_var_type == 'fixedlarge':
        # For fixedlarge, the initial (log-)variance is set like this to get a
        # better decoder log likelihood.
        model_variance = self.betas.to(data.device)
        model_log_variance = torch.log(
            torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)
    else:
        raise NotImplementedError(self.model_var_type)
    # Pick each sample's coefficient at its t and broadcast to data's shape.
    model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
    model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)

    if self.model_mean_type != 'eps':
        # Report the setting that is actually unsupported here.
        raise NotImplementedError(self.model_mean_type)
    x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
    if clip_denoised:
        x_recon = torch.clamp(x_recon, -.5, .5)
    model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)

    assert model_mean.shape == x_recon.shape == data.shape
    assert model_variance.shape == model_log_variance.shape == data.shape
    if return_pred_xstart:
        return model_mean, model_variance, model_log_variance, x_recon
    return model_mean, model_variance, model_log_variance
- self.model_var_type: fixedsmall
# Excerpt from GaussianDiffusion.__init__ (betas, alphas_cumprod and
# alphas_cumprod_prev are defined earlier in that method, not shown here).
# Variance of the DDPM posterior q(x_{t-1} | x_t, x_0):
#   beta_tilde_t = beta_t * (1 - alphas_cumprod_{t-1}) / (1 - alphas_cumprod_t)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.posterior_variance = posterior_variance
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
- posterior_variance๋ DDPM์์ ์ ์ํ๋ variance ๊ฐ
- (์ข ๋ ๋ํ ์ผํ๊ฒ๋, beta ๊ฐ์ด fixed ๋์ด์์ผ๋ฏ๋ก ์์ ๊ฐ์ด ์ ์ํ ์ ์์)
- posterior_log_variance_clipped: variance์ log๊ฐ
def _extract(a, t, x_shape):
    """Gather ``a[t]`` for every batch element and reshape it for broadcasting.

    Returns a tensor of shape ``[batch_size, 1, 1, ...]`` with one trailing
    singleton dimension per non-batch dimension of ``x_shape``.
    """
    (batch_size,) = t.shape
    assert x_shape[0] == batch_size
    picked = torch.gather(a, 0, t)
    assert picked.shape == torch.Size([batch_size])
    broadcast_shape = [batch_size] + [1] * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape)
- model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
- t๋ถ๋ถ์ ํด๋นํ๋ variance ์ถ์ถํ๋ ๋ถ๋ถ (์ธ ๊ฒ ๊ฐ์)
- self.model_mean_type: eps
- x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
- data: T ์์ ์ noised point cloud
- model_output: T ์์ noised point cloud์ ๋ํด์, ์ ํ์ง noise ์์ธก ๊ฐ
def _predict_xstart_from_eps(self, x_t, t, eps):
    """Recover the x_0 prediction implied by x_t and a noise estimate ``eps``.

    Inverts x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps:
        x_0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * eps
    using the precomputed ``sqrt_recip_alphas_cumprod`` and
    ``sqrt_recipm1_alphas_cumprod`` coefficients taken at ``t``.
    """
    assert x_t.shape == eps.shape
    return (
        self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
        self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
    )
# Excerpt from GaussianDiffusion.__init__ (alphas_cumprod is defined earlier in
# that method, not shown here): coefficients used by _predict_xstart_from_eps.
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()  # sqrt(1 / abar_t)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()  # sqrt(1 / abar_t - 1)
- _predict_xstart_from_eps: Xt์ noise๊ฐ์ ํ์ฉํด์ X0๋ฅผ ์ ์ถํ๋ ๋จ๊ณ (์์ ์ด๋ฏธ์ง ์์ ์ฐธ๊ณ )
- Question: sqrt_recipm1_alphas_cumprod์ ๋ถ๋ชจ๊ฐ alphas_cumprod - 1 ์ธ ๊ฒ์ด์ง?
- ๋ญ๊ฐ ์ ์์์์ x0๋ฅผ ๋จ๊ธฐ๊ณ ๋ค ์ข๋ณ์ผ๋ก ๋๊ธฐ๋๊ฒ ๋ง๋ ์์์ธ ๊ฒ ๊ฐ์๋ฐ..
- ์..๋ถ๋ชจ๋ ๊ทธ๋ฅ alphas_cumprod์ด๊ณ (1/alphas_cumprod) - 1 ์..ใ ใ ใ ใ
- Question: sqrt_recipm1_alphas_cumprod์ ๋ถ๋ชจ๊ฐ alphas_cumprod - 1 ์ธ ๊ฒ์ด์ง?
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
-
def q_posterior_mean_variance(self, x_start, x_t, t):
    """
    Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)

        posterior_mean = posterior_mean_coef1[t] * x_start + posterior_mean_coef2[t] * x_t

    Returns (posterior_mean, posterior_variance, posterior_log_variance_clipped);
    the mean has x_t's shape and the variance terms keep the broadcastable
    shape produced by ``self._extract``.
    """
    assert x_start.shape == x_t.shape
    posterior_mean = (
        self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
        self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
    )
    posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
    posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
    assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
            x_start.shape[0])
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

###########################
# Reference: the mean coefficients used above are set in GaussianDiffusion.__init__ as
#   self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
#   self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
- model_mean = ์์ ์ด๋ฏธ์ง์ mean ๊ฐ์ ํด๋นํ๋ ๋ถ๋ถ (equation 7)
- x_start = X0์ ํด๋นํ๋๊ฐ
- ๋ง์ง๋ง, DDPM์ ๋์ค๋ algorithm 2 ์์์ ๊ทผ๊ฑฐํ์ฌ์ ์๋์ ๊ฐ์ด t-1๋ฒ์งธsampling์ ์๋ํจ.
- sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
DDPM์ ์๋ฆฌ๊ฐ, decoder (denoising) ๊ณผ์ ์์ encoder (diffusion)์ ํ๊ท ์ ์์ธกํ๋๋ก ํ๋ จ
- → ์ฆ decoder์ ๋ถํฌ๋ encoder์ ๊ฐ์์ผ ํ๋ฏ๋ก, encoder์ ๋ถํฌ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก
- ํ๊ท +๋ถ์ฐ*๋ ธ์ด์ฆ ํํ๊ฐ ๋จ
๐Inference ์์ฝ
p_sample_loop [img_t: ๋๋ค ๋ ธ์ด์ฆ point cloud]
→ p_sample [img_t์ ๋ ธ์ด์ฆ๋ฅผ T → T-1๋ก ์ ์ ์ค์ฌ๋๊ฐ๋ ๋จ๊ณ; Algorithm2]
→ p_mean_variance [t์์ ์ q๋ถํฌ์ Mean๊ณผ Variance ๊ณ์ฐ]
→ _predict_xstart_from_eps [Noise์ X_t๋ฅผ ๊ธฐ๋ฐ์ผ๋ก, X0๋ฅผ ์์ธก]
→ q_posterior_mean_variance [X0์ X_t๋ฅผ ๋ฐํ์ผ๋ก, q๋ถํฌ์ ํ๊ท ์ ์์ธก]