๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ


[DiT-3D / DDPM Code Analysis]

๋ฐ˜์‘ํ˜•

GitHub Link

https://github.com/DiT-3D/DiT-3D/blob/main/train.py

 


*๋งค์šฐ๋งค์šฐ ๊ธ€์ด ๊ธด ์ดˆ์žฅ๋ฌธ์ž…๋‹ˆ๋‹ค..!

๊ฐ ์ฝ”๋“œ๋ณ„๋กœ ์—„์ฒญ ์ƒ์„ธํ•˜๊ฒŒ ๋ฆฌ๋ทฐํ–ˆ๊ณ , ์ตœ๋Œ€ํ•œ ํ๋ฆ„์— ๋”ฐ๋ผ์„œ ์ฝ”๋“œ์™€ ์ˆ˜์‹์„ ๋ถ™์—ฌ์„œ ์„ค๋ช…ํ•˜์˜€์Šต๋‹ˆ๋‹ค.

*DiT-3D ์ฝ”๋“œ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์„ค๋ช…ํ•˜๊ณ  ์žˆ์ง€๋งŒ, dataloader๋ฅผ ์ œ์™ธํ•œ ๋‚˜๋จธ์ง€ ๋ฆฌ๋ทฐ๋Š” 2D ๊ธฐ๋ฐ˜์˜ DDPM or Diffusion Transformer ์ฝ”๋“œ์˜ ํ๋ฆ„์œผ๋กœ ์ดํ•ดํ•˜์…”๋„ ๋ฌด๊ด€ํ•ฉ๋‹ˆ๋‹ค.


Contents

- ๐ŸฆŽDataLoader Code

- ๐ŸModel Training Code

- ๐ŸฒModel Inference Code

- ๐Ÿ‰Inference Summary


๐ŸฆŽDataLoader Code

# train.py line 565
train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)

## get_dataset function
def get_dataset(dataroot, npoints, category):
    tr_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=category.split(','), split='train',
        tr_sample_size=npoints, # 2048
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        random_subsample=True)
    te_dataset = ShapeNet15kPointClouds(root_dir=dataroot,
        categories=category.split(','), split='val',
        tr_sample_size=npoints,
        te_sample_size=npoints,
        scale=1.,
        normalize_per_shape=False,
        normalize_std_per_axis=False,
        all_points_mean=tr_dataset.all_points_mean,
        all_points_std=tr_dataset.all_points_std,
    )
    return tr_dataset, te_dataset
# tr_sample_size = 2048
# te_sample_size = 2048
class ShapeNet15kPointClouds(Uniform15KPC):
    def __init__(self, root_dir="data/ShapeNetCore.v2.PC15k",
                 categories=['airplane'], tr_sample_size=10000, te_sample_size=2048,
                 split='train', scale=1., normalize_per_shape=False,
                 normalize_std_per_axis=False, box_per_shape=False,
                 random_subsample=False,
                 all_points_mean=None, all_points_std=None,
                 use_mask=False):
        self.root_dir = root_dir
        self.split = split
        assert self.split in ['train', 'test', 'val']
        self.tr_sample_size = tr_sample_size
        self.te_sample_size = te_sample_size
        self.cates = categories
        if 'all' in categories:
            self.synset_ids = list(cate_to_synsetid.values())
        else:
            self.synset_ids = [cate_to_synsetid[c] for c in self.cates]

        # assert 'v2' in root_dir, "Only supporting v2 right now."
        self.gravity_axis = 1
        self.display_axis_order = [0, 2, 1]

        super(ShapeNet15kPointClouds, self).__init__(
            root_dir, self.synset_ids,
            tr_sample_size=tr_sample_size,
            te_sample_size=te_sample_size,
            split=split, scale=scale,
            normalize_per_shape=normalize_per_shape, box_per_shape=box_per_shape,
            normalize_std_per_axis=normalize_std_per_axis,
            random_subsample=random_subsample,
            all_points_mean=all_points_mean, all_points_std=all_points_std,
            input_dim=3, use_mask=use_mask)
class Uniform15KPC(Dataset):
    def __init__(self, root_dir, subdirs, tr_sample_size=10000,
                 te_sample_size=10000, split='train', scale=1.,
                 normalize_per_shape=False, box_per_shape=False,
                 random_subsample=False,
                 normalize_std_per_axis=False,
                 all_points_mean=None, all_points_std=None,
                 input_dim=3, use_mask=False):
        self.root_dir = root_dir
        self.split = split
        self.in_tr_sample_size = tr_sample_size
        self.in_te_sample_size = te_sample_size
        self.subdirs = subdirs
        self.scale = scale
        self.random_subsample = random_subsample
        self.input_dim = input_dim
        self.use_mask = use_mask
        self.box_per_shape = box_per_shape
        if use_mask:
            self.mask_transform = PointCloudMasks(radius=5, elev=5, azim=90)

        self.all_cate_mids = []
        self.cate_idx_lst = []
        self.all_points = []
        for cate_idx, subd in enumerate(self.subdirs):
            # NOTE: [subd] here is synset id
            sub_path = os.path.join(root_dir, subd, self.split)
            if not os.path.isdir(sub_path):
                print("Directory missing : %s" % sub_path)
                continue

            all_mids = []
            for x in os.listdir(sub_path):
                if not x.endswith('.npy'):
                    continue
                all_mids.append(os.path.join(self.split, x[:-len('.npy')]))

            # NOTE: [mid] contains the split: i.e. "train/<mid>" or "val/<mid>" or "test/<mid>"
            for mid in all_mids:
                # obj_fname = os.path.join(sub_path, x)
                obj_fname = os.path.join(root_dir, subd, mid + ".npy")
                try:
                    point_cloud = np.load(obj_fname)  # (15k, 3)

                except:
                    continue

                assert point_cloud.shape[0] == 15000
                self.all_points.append(point_cloud[np.newaxis, ...])
                self.cate_idx_lst.append(cate_idx)
                self.all_cate_mids.append((subd, mid))

        # Shuffle the index deterministically (based on the number of examples)
        self.shuffle_idx = list(range(len(self.all_points)))
        random.Random(38383).shuffle(self.shuffle_idx)
        self.cate_idx_lst = [self.cate_idx_lst[i] for i in self.shuffle_idx]
        self.all_points = [self.all_points[i] for i in self.shuffle_idx]
        self.all_cate_mids = [self.all_cate_mids[i] for i in self.shuffle_idx]

        # Normalization
        self.all_points = np.concatenate(self.all_points)  # (N, 15000, 3)
        self.normalize_per_shape = normalize_per_shape
        self.normalize_std_per_axis = normalize_std_per_axis
        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        elif self.normalize_per_shape:  # per shape normalization
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.mean(axis=1).reshape(B, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(B, N, -1).std(axis=1).reshape(B, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(B, -1).std(axis=1).reshape(B, 1, 1)
        elif self.box_per_shape:
            B, N = self.all_points.shape[:2]
            self.all_points_mean = self.all_points.min(axis=1).reshape(B, 1, input_dim)

            self.all_points_std = self.all_points.max(axis=1).reshape(B, 1, input_dim) - self.all_points.min(axis=1).reshape(B, 1, input_dim)

        else:  # normalize across the dataset
            self.all_points_mean = self.all_points.reshape(-1, input_dim).mean(axis=0).reshape(1, 1, input_dim)
            if normalize_std_per_axis:
                self.all_points_std = self.all_points.reshape(-1, input_dim).std(axis=0).reshape(1, 1, input_dim)
            else:
                self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        if self.box_per_shape:
            self.all_points = self.all_points - 0.5
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

        self.tr_sample_size = min(10000, tr_sample_size)
        self.te_sample_size = min(5000, te_sample_size)
        print("Total number of data:%d" % len(self.train_points))
        print("Min number of points: (train)%d (test)%d"
              % (self.tr_sample_size, self.te_sample_size))
        assert self.scale == 1, "Scale (!= 1) is deprecated"

    def get_pc_stats(self, idx):
        if self.normalize_per_shape or self.box_per_shape:
            m = self.all_points_mean[idx].reshape(1, self.input_dim)
            s = self.all_points_std[idx].reshape(1, -1)
            return m, s

        return self.all_points_mean.reshape(1, -1), self.all_points_std.reshape(1, -1)

    def renormalize(self, mean, std):
        self.all_points = self.all_points * self.all_points_std + self.all_points_mean
        self.all_points_mean = mean
        self.all_points_std = std
        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]

    def __len__(self):
        return len(self.train_points)

    def __getitem__(self, idx):
        tr_out = self.train_points[idx]
        if self.random_subsample:
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()

        te_out = self.test_points[idx]
        if self.random_subsample:
            te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
        else:
            te_idxs = np.arange(self.te_sample_size)
        te_out = torch.from_numpy(te_out[te_idxs, :]).float()

        m, s = self.get_pc_stats(idx)
        cate_idx = self.cate_idx_lst[idx]
        sid, mid = self.all_cate_mids[idx]

        out = {
            'idx': idx,
            'train_points': tr_out,
            'test_points': te_out,
            'mean': m, 'std': s, 'cate_idx': cate_idx,
            'sid': sid, 'mid': mid
        }

        if self.use_mask:
            # masked = torch.from_numpy(self.mask_transform(self.all_points[idx]))
            # ss = min(masked.shape[0], self.in_tr_sample_size//2)
            # masked = masked[:ss]
            #
            # tr_mask = torch.ones_like(masked)
            # masked = torch.cat([masked, torch.zeros(self.in_tr_sample_size - ss, 3)],dim=0)#F.pad(masked, (self.in_tr_sample_size-masked.shape[0], 0), "constant", 0)
            #
            # tr_mask =  torch.cat([tr_mask, torch.zeros(self.in_tr_sample_size- ss, 3)],dim=0)#F.pad(tr_mask, (self.in_tr_sample_size-tr_mask.shape[0], 0), "constant", 0)
            # out['train_points_masked'] = masked
            # out['train_masks'] = tr_mask
            tr_mask = self.mask_transform(tr_out)
            out['train_masks'] = tr_mask

        return out
  • Pre-processing follows the PointFlow paper.
    • self.all_points ⇒ (N, 15000, 3)
    • (1) Normalization: the mean and std are computed over the point axis and applied as follows
    self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
    
    • (2) Training๊ณผ Testingํ•  ๋•Œ 2048๊ฐœ์˜ points๋ฅผ randomํ•˜๊ฒŒ samplingํ•จ.
     def __getitem__(self, idx):
        tr_out = self.train_points[idx]
        if self.random_subsample: # True in this configuration
            tr_idxs = np.random.choice(tr_out.shape[0], self.tr_sample_size)
        else:
            tr_idxs = np.arange(self.tr_sample_size)
        tr_out = torch.from_numpy(tr_out[tr_idxs, :]).float()
    
        te_out = self.test_points[idx]
        if self.random_subsample:
            te_idxs = np.random.choice(te_out.shape[0], self.te_sample_size)
        else:
            te_idxs = np.arange(self.te_sample_size)
        te_out = torch.from_numpy(te_out[te_idxs, :]).float()
    
        m, s = self.get_pc_stats(idx)
        # m์€ mean
        # s๋Š” std๊ฐ’
    
        out = {
            'idx': idx,
            'train_points': tr_out,
            'test_points': te_out,
            'mean': m, 'std': s, 'cate_idx': cate_idx,
            'sid': sid, 'mid': mid
        }
    
        return out
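For context, here is a minimal sketch of how this dataset would typically be wrapped for training (standard PyTorch; the batch size, worker count, and variable names are illustrative assumptions, not taken verbatim from train.py):

import torch

train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=16, shuffle=True,
    num_workers=4, drop_last=True)

for batch in train_loader:
    # (B, 2048, 3) -> (B, 3, 2048), the layout the diffusion model expects
    x = batch['train_points'].transpose(1, 2)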

์˜๋ฌธ์  1๊ฐ€์ง€

*๊ทผ๋ฐ train-test๋ฅผ ๊ตฌ๋ถ„ํ•  ๋•Œ, point๋ฅผ ๊ธฐ์ค€์œผ๋กœ ํ•˜๋‚˜?

  • ์ฒ˜์Œ ๋ณธ point cloud ์ขŒํ‘œ๋ผ์„œ ๊ดœ์ฐฎ์€๊ฑด๊ฐ€?
self.train_points = self.all_points[:, :10000]
self.test_points = self.all_points[:, 10000:]

๐Ÿ Model Training

# https://github.com/DiT-3D/DiT-3D/blob/cfcc16b62004d735e935d14e0e8dc2d00a96041d/train.py#L410
class Model(nn.Module):
    def __init__(self, args, betas, loss_type: str, model_mean_type: str, model_var_type: str):
        super(Model, self).__init__()
        self.diffusion = GaussianDiffusion(betas, loss_type, model_mean_type, model_var_type)

        # check DiT3D_models_WindAttn: the windowed-attention variant is used when window_size > 0
        if args.window_size > 0:
            self.model = DiT3D_models_WindAttn[args.model_type](pretrained=args.use_pretrained, 
                                                   input_size=args.voxel_size, 
                                                   window_size=args.window_size, 
                                                   window_block_indexes=args.window_block_indexes, 
                                                   num_classes=args.num_classes
                                                )
        else:
            self.model = DiT3D_models[args.model_type](pretrained=args.use_pretrained, 
                                                    input_size=args.voxel_size, 
                                                    num_classes=args.num_classes
                                                    )

    def prior_kl(self, x0):
        return self.diffusion._prior_bpd(x0)

    def all_kl(self, x0, y, clip_denoised=True):
        total_bpd_b, vals_bt, prior_bpd_b, mse_bt =  self.diffusion.calc_bpd_loop(self._denoise, x0, y, clip_denoised)

        return {
            'total_bpd_b': total_bpd_b,
            'terms_bpd': vals_bt,
            'prior_bpd_b': prior_bpd_b,
            'mse_bt':mse_bt
        }

    def _denoise(self, data, t, y):
        B, D, N= data.shape
        assert data.dtype == torch.float
        assert t.shape == torch.Size([B]) and t.dtype == torch.int64

        out = self.model(data, t, y)

        assert out.shape == torch.Size([B, D, N])
        return out

    def get_loss_iter(self, data, noises=None, y=None):
        B, D, N = data.shape                           # [16, 3, 2048]
        t = torch.randint(0, self.diffusion.num_timesteps, size=(B,), device=data.device)

        if noises is not None:
            noises[t!=0] = torch.randn((t!=0).sum(), *noises.shape[1:]).to(noises)

        losses = self.diffusion.p_losses(
            denoise_fn=self._denoise, data_start=data, t=t, noise=noises, y=y)
        assert losses.shape == t.shape == torch.Size([B])
        return losses

    def gen_samples(self, shape, device, y, noise_fn=torch.randn,
                    clip_denoised=True,
                    keep_running=False):
        return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, y=y, noise_fn=noise_fn,
                                            clip_denoised=clip_denoised,
                                            keep_running=keep_running)

    def gen_sample_traj(self, shape, device, y, freq, noise_fn=torch.randn,
                    clip_denoised=True,keep_running=False):
        return self.diffusion.p_sample_loop_trajectory(self._denoise, shape=shape, device=device, y=y, noise_fn=noise_fn, freq=freq,
                                                       clip_denoised=clip_denoised,
                                                       keep_running=keep_running)

    def train(self):
        self.model.train()

    def eval(self):
        self.model.eval()

    def multi_gpu_wrapper(self, f):
        self.model = f(self.model)
  • Training runs through the get_loss_iter function; a minimal sketch of the surrounding loop follows below.
    • self.diffusion → the GaussianDiffusion class
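Before the loss itself, here is a minimal sketch of how get_loss_iter sits inside a training step (the optimizer, learning rate, and loop structure are illustrative assumptions, not the exact train.py code):

# hypothetical training step; `model` is the Model wrapper above
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

for batch in train_loader:
    x = batch['train_points'].transpose(1, 2).cuda()   # (B, 3, 2048)
    y = batch['cate_idx'].cuda()                       # (B,) class labels

    loss = model.get_loss_iter(x, noises=None, y=y).mean()  # (B,) per-sample losses -> scalar

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()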
# p_losses (the denoising training loss)
def p_losses(self, denoise_fn, data_start, t, noise=None, y=None):
    """
    Training loss calculation
    """
    B, D, N = data_start.shape
    assert t.shape == torch.Size([B])

    if noise is None:
        noise = torch.randn(data_start.shape, dtype=data_start.dtype, device=data_start.device)
    assert noise.shape == data_start.shape and noise.dtype == data_start.dtype

    # q_sample (the forward diffusion step) is defined below
    data_t = self.q_sample(x_start=data_start, t=t, noise=noise)

    if self.loss_type == 'mse':
        # predict the noise instead of x_start. seems to be weighted naturally like SNR
        eps_recon = denoise_fn(data_t, t, y)
        assert data_t.shape == data_start.shape
        assert eps_recon.shape == torch.Size([B, D, N])
        assert eps_recon.shape == data_start.shape
        losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
    elif self.loss_type == 'kl':
        losses = self._vb_terms_bpd(
            denoise_fn=denoise_fn, data_start=data_start, data_t=data_t, t=t, y=y, clip_denoised=False,
            return_pred_xstart=False)
    else:
        raise NotImplementedError(self.loss_type)

    assert losses.shape == torch.Size([B])
    return losses
    
    
# q_sample: produces x_t, the input with noise applied
def q_sample(self, x_start, t, noise=None):
    """
    Diffuse the data (t == 0 means diffused for 1 step)
    """
    if noise is None:
        noise = torch.randn(x_start.shape, device=x_start.device)
    assert noise.shape == x_start.shape
    
    # ๋ชจ๋ธ initialization ๋  ๋•Œ ๋งŒ๋“ค์–ด์ง„ ๊ฐ’ ํ™œ์šฉ
    ## alphas: 1 - beta
    ## self.sqrt_alphas_cumprod: sqrt(t์‹œ์ ๊นŒ์ง€์˜ alpha cumprod ์—ฐ์‚ฐ๋œ ๊ฐ’๋“ค)
    ## self.sqrt_one_minus_alphas_cumprod: sqrt(1-cum_alphas)
    #### x_start * self.sqrt_alphas + sqrt(1-alphas) * noises
    #### DDPM์˜ 11๋ฒˆ์งธ ์ˆ˜์‹(?)
    return (
            self._extract(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
            self._extract(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise
    )

def _extract(a, t, x_shape):
    """
    Extract some coefficients at specified timesteps,
    then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    bs, = t.shape
    assert x_shape[0] == bs
    out = torch.gather(a, 0, t)
    assert out.shape == torch.Size([bs])
    return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
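To make _extract concrete, here is a small self-contained example (the shapes are illustrative):

import torch

sqrt_alphas_cumprod = torch.linspace(1.0, 0.1, steps=1000)  # (T,) per-timestep coefficients
t = torch.tensor([0, 499, 999])                             # (B,) sampled timesteps
x = torch.randn(3, 3, 2048)                                 # (B, D, N) batch of point clouds

out = torch.gather(sqrt_alphas_cumprod, 0, t)               # (B,) one coefficient per sample
out = out.reshape(3, 1, 1)                                  # (B, 1, 1) so it broadcasts against x
x_scaled = out * x                                          # scales every point of each sample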
  • p_losses → q_sample → back in p_losses (MSE computation)
    • q_sample corrupts the input x with noise.
    • The noised input goes into denoise_fn(data_t, t, y),
      • which is the Model class's _denoise function, i.e., the forward pass through the DiT-3D blocks.

q_sample์—์„œ ์ผ์–ด๋‚˜๋Š” ์ผ; x0๋กœ๋ถ€ํ„ฐ xt๋ฅผ ์ƒ์„ฑ (noise๋ฅผ ์ถ”๊ฐ€ํ•ด์„œ)

DiT3D_models_WindAttn = {
    'DiT-XL/2': DiT_XL_2,  'DiT-XL/4': DiT_XL_4,  'DiT-XL/8': DiT_XL_8,
    'DiT-L/2':  DiT_L_2,   'DiT-L/4':  DiT_L_4,   'DiT-L/8':  DiT_L_8,
    'DiT-B/2':  DiT_B_2,   'DiT-B/4':  DiT_B_4,   'DiT-B/8':  DiT_B_8,
    'DiT-S/2':  DiT_S_2,   'DiT-S/4':  DiT_S_4,   'DiT-S/8':  DiT_S_8,
}

def DiT_S_2(pretrained=False, **kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
def forward(self, x, t, y):
    """
    Forward pass of DiT.
    x: (N, C, P) tensor of spatial inputs (point clouds or latent representations of images)
    t: (N,) tensor of diffusion timesteps
    y: (N,) tensor of class labels
    """

    # Voxelization
    features, coords = x, x
    x, voxel_coords = self.voxelization(features, coords)

    x = self.x_embedder(x) 
    x = x + self.pos_embed

    t = self.t_embedder(t)                  
    y = self.y_embedder(y, self.training)    
    c = t + y                           

    for block in self.blocks:
        x = block(x, c)                      
    x = self.final_layer(x, c)
    x = self.unpatchify_voxels(x)           

    # Devoxelization
    x = F.trilinear_devoxelize(x, voxel_coords, self.input_size, self.training)

    return x
  • self.voxelization → converts the point cloud into a voxel grid
    • N×3 → V×V×V×3 (default V: 32)
  • self.x_embedder: a Conv3D block
    • (B, C, X, Y, Z) → (B, E, X, Y, Z)
    • C is 3, the coordinate dimension
  • self.t_embedder: the timestep embedding
    • t is a batch-sized tensor of integer timesteps
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    
    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
    
    • ์œ„์˜ ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด์„œ, sincos ํ•จ์ˆ˜ ์ ์šฉ + 2๊ฐœ์˜ MLP & GeLU layer ํ†ต๊ณผ
  • self.y_embedder: y(class) embedding
    • self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
  • block์€ DiTBlock class๋กœ ๋“ค๊ฐ
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, 
                 num_heads, 
                 mlp_ratio=4.0, 
                 use_rel_pos=False,
                 rel_pos_zero_init=True,
                 window_size=0,
                 use_residual_block=False,
                 input_size=None, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size, window_size),
        )

        self.input_size = input_size
        
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # approx_gelu = lambda: nn.GELU(approximate="tanh")
        approx_gelu = lambda: nn.GELU() # for torch 1.7.1
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

        self.window_size = window_size

    def forward(self, x, c):
		    # c๊ฐ€ condition
		    # MLP๋ฅผ ํ†ตํ•ด์„œ shift์™€ scale factors๋ฅผ ์ƒ์„ฑ
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
                
        shortcut = x
        x = self.norm1(x)
        B = x.shape[0]
        x = x.reshape(B, self.input_size[0], self.input_size[1], self.input_size[2], -1)
        
        # Window partition
        if self.window_size > 0:
            X, Y, Z = x.shape[1], x.shape[2], x.shape[3]
            x, pad_xyz = window_partition(x, self.window_size)

        x = x.reshape(B, self.input_size[0] * self.input_size[1] * self.input_size[2], -1)
        x = modulate(x, shift_msa, scale_msa)
        x = x.reshape(B, self.input_size[0], self.input_size[1], self.input_size[2], -1)

        x = self.attn(x)
        
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_xyz, (X, Y, Z))
        x = x.reshape(B, self.input_size[0] * self.input_size[1] * self.input_size[2], -1)

        x = shortcut + gate_msa.unsqueeze(1) * x
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        
        return x
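The modulate helper used above is not shown in this snippet; in the reference 2D DiT codebase it is the following one-liner, reproduced here for readability:

def modulate(x, shift, scale):
    # shift/scale are (B, hidden); broadcast them over the token dimension
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)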
  • self.adaLN_modulation(c).chunk(6, dim=1)
    • generates the scale and shift factors from the condition.
  • Once all DiTBlocks finish, x enters the final layer of the DiT class.
    • out_channels is 3, so the output comes back out in coordinate form.
    • class FinalLayer(nn.Module):
          """
          The final layer of DiT.
          """
          def __init__(self, hidden_size, patch_size, out_channels):
              super().__init__()
              self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
              self.linear = nn.Linear(hidden_size, patch_size * patch_size * patch_size * out_channels, bias=True)
              self.adaLN_modulation = nn.Sequential(
                  nn.SiLU(),
                  nn.Linear(hidden_size, 2 * hidden_size, bias=True)
              )
      
          def forward(self, x, c):
              shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
              x = modulate(self.norm_final(x), shift, scale)
              x = self.linear(x)
              return x # (p, p, p, c)
      
  • Loss ๊ณ„์‚ฐ์€, p_losses ํ•จ์ˆ˜์—์„œ
    • losses = ((noise - eps_recon)**2).mean(dim=list(range(1, len(data_start.shape))))
    • ์œ„์™€ ๊ฐ™์ด ๊ณ„์‚ฐ๋จ. (noise ์ •๋„๋ฅผ ์˜ˆ์ธก; MSE)

Loss function: the MSE between the true and predicted noise
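In LaTeX, this is the simplified DDPM training objective (Eq. (14) in Ho et al., 2020):

L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right)\right\rVert^2\right]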


๐ŸฒModel Inference

gen = model.gen_samples(x.shape, gpu, new_y_chain(gpu,y.shape[0],opt.num_classes), clip_denoised=False).detach().cpu()

→ Sampling enters through the gen_samples function of the Model class.

# the gen_samples function
def gen_samples(self, shape, device, y, noise_fn=torch.randn,
                    clip_denoised=True, # False at inference time
                    keep_running=False):
    # delegates into the GaussianDiffusion class
    return self.diffusion.p_sample_loop(self._denoise, shape=shape, device=device, y=y, noise_fn=noise_fn,
                                        clip_denoised=clip_denoised,
                                        keep_running=keep_running)

# the p_sample_loop function
def p_sample_loop(self, denoise_fn, shape, device, y,
                      noise_fn=torch.randn, clip_denoised=True, keep_running=False):
    """
    Generate samples
    keep_running: True if we run 2 x num_timesteps, False if we just run num_timesteps

    """

    assert isinstance(shape, (tuple, list))
    #### noise_fn = torch.randn
    # img_t starts as a pure random-noise point cloud
    img_t = noise_fn(size=shape, dtype=torch.float, device=device)
    
    # keep_running = False
    # self.num_timesteps = 1000
    for t in reversed(range(0, self.num_timesteps if not keep_running else len(self.betas))):
        t_ = torch.empty(shape[0], dtype=torch.int64, device=device).fill_(t)
        img_t = self.p_sample(denoise_fn=denoise_fn, data=img_t,t=t_, noise_fn=noise_fn, y=y,
                              clip_denoised=clip_denoised, return_pred_xstart=False)

    assert img_t.shape == shape
    return img_t
  • Samplingํ•  ๋•Œ, timesteps = 1000์œผ๋กœ ์„ค์ •ํ•จ.
  • p_sample_loop ํ•จ์ˆ˜๋Š”, p_sample์„ timesteps ๋งŒํผ ๋ฐ˜๋ณต ์ˆ˜ํ–‰ํ•œ ๊ฒƒ.
def p_sample(self, denoise_fn, data, t, noise_fn, y, clip_denoised=False, return_pred_xstart=False):
    """
    Sample from the model
    """
    model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(denoise_fn, data=data, t=t, y=y, clip_denoised=clip_denoised,
                                                             return_pred_xstart=True)
    noise = noise_fn(size=data.shape, dtype=data.dtype, device=data.device)
    assert noise.shape == data.shape
    # no noise when t == 0
    nonzero_mask = torch.reshape(1 - (t == 0).float(), [data.shape[0]] + [1] * (len(data.shape) - 1))

    # Algorithm 2 -> Sampling
    sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
    assert sample.shape == pred_xstart.shape
    return (sample, pred_xstart) if return_pred_xstart else sample
    
    
# the p_mean_variance function
def p_mean_variance(self, denoise_fn, data, t, y, clip_denoised: bool, return_pred_xstart: bool):
    model_output = denoise_fn(data, t, y)

    ## self.model_var_type: 'fixedsmall'
    if self.model_var_type in ['fixedsmall', 'fixedlarge']:
        # below: only log_variance is used in the KL computations
        model_variance, model_log_variance = {
            # for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
            'fixedlarge': (self.betas.to(data.device),
                           torch.log(torch.cat([self.posterior_variance[1:2], self.betas[1:]])).to(data.device)),
            'fixedsmall': (self.posterior_variance.to(data.device), self.posterior_log_variance_clipped.to(data.device)),
        }[self.model_var_type]
        model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
        model_log_variance = self._extract(model_log_variance, t, data.shape) * torch.ones_like(data)
    else:
        raise NotImplementedError(self.model_var_type)

    if self.model_mean_type == 'eps':
        x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)

        if clip_denoised: # False
            x_recon = torch.clamp(x_recon, -.5, .5)

        model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
    else:
        raise NotImplementedError(self.loss_type)

    assert model_mean.shape == x_recon.shape == data.shape
    assert model_variance.shape == model_log_variance.shape == data.shape
    if return_pred_xstart: # True here
        return model_mean, model_variance, model_log_variance, x_recon
    else:
        return model_mean, model_variance, model_log_variance

 

Definition of the posterior variance

  • self.model_var_type: 'fixedsmall'
  • The posterior variance is precomputed at initialization:
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.posterior_variance = posterior_variance
    # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
    self.posterior_log_variance_clipped = torch.log(torch.max(posterior_variance, 1e-20 * torch.ones_like(posterior_variance)))
  • posterior_variance is the variance defined in DDPM.
    • (More precisely: because the betas are fixed, it can be written in the closed form above.)
  • posterior_log_variance_clipped: the log of that variance.
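Written out, this is the posterior variance from DDPM Eq. (7):

\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t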
def _extract(a, t, x_shape):
    """
    Extract some coefficients at specified timesteps,
    then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    bs, = t.shape
    assert x_shape[0] == bs
    out = torch.gather(a, 0, t)
    assert out.shape == torch.Size([bs])
    return torch.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))
  • model_variance = self._extract(model_variance, t, data.shape) * torch.ones_like(data)
  • this extracts the variance entry corresponding to timestep t (as far as I can tell).

X0๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐฉ๋ฒ•; x0๋ฅผ ๋‚จ๊ธฐ๊ณ  ๋‹ค ์ขŒ๋ณ€์œผ๋กœ ๋„˜๊น€

  • self.model_mean_type: 'eps'
  • x_recon = self._predict_xstart_from_eps(data, t=t, eps=model_output)
    • data: the noised point cloud at the current timestep t
    • model_output: the predicted noise that was added to that point cloud
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
                self._extract(self.sqrt_recip_alphas_cumprod.to(x_t.device), t, x_t.shape) * x_t -
                self._extract(self.sqrt_recipm1_alphas_cumprod.to(x_t.device), t, x_t.shape) * eps
        )
    
    self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod).float()
    self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1).float()
    
    • _predict_xstart_from_eps: recovers x_0 from x_t and the predicted noise (see the rearranged equation above)
      • Question: why would the denominator of sqrt_recipm1_alphas_cumprod be alphas_cumprod - 1?
        • The correct equation should just isolate x_0 and move everything else to the other side...
        • Ah, the denominator is simply alphas_cumprod, and the term is (1 / alphas_cumprod) - 1. haha
  • model_mean, _, _ = self.q_posterior_mean_variance(x_start=x_recon, x_t=data, t=t)
  • def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
                self._extract(self.posterior_mean_coef1.to(x_start.device), t, x_t.shape) * x_start +
                self._extract(self.posterior_mean_coef2.to(x_start.device), t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance.to(x_start.device), t, x_t.shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped.to(x_start.device), t, x_t.shape)
        assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
                x_start.shape[0])
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    
    ###########################
    self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
    self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

Encoder์˜ mean ๊ฐ’ ์˜ˆ์ธก

  • model_mean corresponds to the posterior mean in DDPM equation 7 (written out below)
    • x_start is the value corresponding to the predicted x_0
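The posterior mean that posterior_mean_coef1 and posterior_mean_coef2 implement (DDPM Eq. (7)):

\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\,x_t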

 

DDPM Algorithm 2 (sampling)

  • ๋งˆ์ง€๋ง‰, DDPM์— ๋‚˜์˜ค๋Š” algorithm 2 ์ˆ˜์‹์— ๊ทผ๊ฑฐํ•˜์—ฌ์„œ ์•„๋ž˜์™€ ๊ฐ™์ด t-1๋ฒˆ์งธsampling์„ ์‹œ๋„ํ•จ.
    • sample = model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
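In equation form (z is fresh Gaussian noise, and nonzero_mask drops it at t = 0):

x_{t-1} = \tilde\mu_t(x_t, \hat{x}_0) + \mathbb{1}[t > 0]\,\exp\!\left(\tfrac{1}{2}\log\tilde\beta_t\right) z,
\qquad z \sim \mathcal{N}(0, \mathbf{I})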

DDPM์˜ ์›๋ฆฌ๊ฐ€, decoder (denoising) ๊ณผ์ •์—์„œ encoder (diffusion)์˜ ํ‰๊ท ์„ ์˜ˆ์ธกํ•˜๋„๋ก ํ›ˆ๋ จ

  • → ์ฆ‰ decoder์˜ ๋ถ„ํฌ๋Š” encoder์™€ ๊ฐ™์•„์•ผ ํ•˜๋ฏ€๋กœ, encoder์˜ ๋ถ„ํฌ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ
  • ํ‰๊ท +๋ถ„์‚ฐ*๋…ธ์ด์ฆˆ ํ˜•ํƒœ๊ฐ€ ๋จ

๐Ÿ‰Inference ์š”์•ฝ

p_sample_loop [img_t: a random-noise point cloud]

→ p_sample [gradually removes img_t's noise, stepping T → T-1; Algorithm 2]

    → p_mean_variance [computes the mean and variance of the q distribution at timestep t]

        → _predict_xstart_from_eps [predicts x_0 from the predicted noise and x_t]

        → q_posterior_mean_variance [predicts the mean of the q distribution from x_0 and x_t]

 

๋ฐ˜์‘ํ˜•