# utils.py (3.86 KiB)
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
# Surface-normal estimation; method credited to Zhang Songyan.
def get_normal(target_disp: torch.Tensor) -> torch.Tensor:
    """Estimate per-pixel surface normals from a single-channel map.

    Horizontal and vertical gradients are computed with 3x3 Scharr kernels
    (weights 3/10/3). The negated gradients ``(-gx, -gy)`` are stacked with a
    constant ``1`` channel, and each pixel's 3-vector is L2-normalised along
    the channel dimension. The gradients are not rescaled, so they carry the
    kernel's gain.

    Args:
        target_disp: Tensor of shape ``(N, 1, H, W)``. The convolution weight
            has a single input channel, so exactly one channel is required.

    Returns:
        Tensor of shape ``(N, 3, H, W)`` with the same dtype and device as
        ``target_disp``, holding unit-length vectors.
    """
    kernel_x = torch.tensor(
        [[-3.0, 0.0, 3.0], [-10.0, 0.0, 10.0], [-3.0, 0.0, 3.0]],
        dtype=target_disp.dtype,
        device=target_disp.device,
    )
    # The vertical kernel is the transpose of the horizontal one.
    kernel_y = kernel_x.t()
    weight = torch.stack((kernel_x, kernel_y)).unsqueeze(1)  # (2, 1, 3, 3)

    # (-gx, -gy, 1) is the (unnormalised) normal of the surface z = f(x, y).
    grad = -F.conv2d(target_disp, weight, padding=1)
    n, _, h, w = grad.shape

    # Allocate directly with grad's dtype/device; no host alloc + transfer.
    ones = grad.new_ones((n, 1, h, w))
    return F.normalize(torch.cat((grad, ones), dim=1), dim=1)
class InputPadder:
    """Pad images so that their height and width are divisible by ``divis_by``.

    Args:
        dims: Shape of the tensors to be padded; only the last two entries
            (height, width) are used.
        mode: ``"sintel"`` splits the padding between both sides of each
            dimension; any other value splits the width padding between left
            and right but puts all height padding at the bottom.
        divis_by: Required divisor of the padded height and width.
    """

    def __init__(self, dims, mode="sintel", divis_by=8):
        self.ht, self.wd = dims[-2:]
        # Smallest non-negative amount that rounds each size up to a multiple.
        pad_ht = -self.ht % divis_by
        pad_wd = -self.wd % divis_by
        left, right = pad_wd // 2, pad_wd - pad_wd // 2
        if mode == "sintel":
            top, bottom = pad_ht // 2, pad_ht - pad_ht // 2
        else:
            top, bottom = 0, pad_ht
        # F.pad order for the last two dims: (left, right, top, bottom).
        self._pad = [left, right, top, bottom]

    def pad(self, *inputs):
        """Replicate-pad each 4-D input; returns a list in input order."""
        assert all((x.ndim == 4) for x in inputs)
        return [F.pad(x, self._pad, mode="replicate") for x in inputs]

    def unpad(self, x):
        """Crop a padded 4-D tensor back to its pre-padding height and width."""
        assert x.ndim == 4
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0] : c[1], c[2] : c[3]]
def forward_interpolate(flow):
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]
    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
    x1 = x0 + dx
    y1 = y0 + dy
    x1 = x1.reshape(-1)
    y1 = y1.reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)
    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1 = x1[valid]
    y1 = y1[valid]
    dx = dx[valid]
    dy = dy[valid]
    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method="nearest", fill_value=0