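# NOTE: the header of this file was captured from a web code viewer and is not code.
# It contained the viewer message "加载文件时发生错误。请再试一次。" three times
# ("An error occurred while loading the file. Please try again.").
# Authored by huangzhuofei in fa5a2470.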
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
# Surface-normal estimation from a disparity map (credit: Zhang Songyan)
def get_normal(target_disp: torch.Tensor):
    """Estimate per-pixel surface normals from a single-channel disparity map.

    Horizontal and vertical image gradients are taken with 3x3 Scharr kernels
    (zero padding, so border pixels see an artificial step to 0). The gradients
    are negated and stacked with a constant z-component of 1, then normalized
    to unit length along the channel axis.

    The kernels are not rescaled. On a linear ramp of slope ``s``, the x/y
    components before normalization are ``-32 * s``: the kernel weights sum
    to 16 per side, times the 2-pixel central difference.

    Args:
        target_disp: Tensor of shape (N, 1, H, W). The single input channel is
            required by the (2, 1, 3, 3) convolution weight.

    Returns:
        Tensor of shape (N, 3, H, W) with the same dtype and device as
        ``target_disp``, holding unit normals ordered (x, y, z).
    """
    scharr = torch.tensor(
        [
            [[-3, 0, 3], [-10, 0, 10], [-3, 0, 3]],  # responds to change along W (x)
            [[-3, -10, -3], [0, 0, 0], [3, 10, 3]],  # responds to change along H (y)
        ],
        dtype=target_disp.dtype,
        device=target_disp.device,
    ).unsqueeze(1)  # -> (out_channels=2, in_channels=1, 3, 3)
    # A normal to the surface z = f(x, y) is proportional to (-df/dx, -df/dy, 1),
    # hence the negation.
    grad = -F.conv2d(target_disp, scharr, padding=1)
    # ones_like keeps the z-channel in grad's dtype/device. A plain torch.ones
    # would default to float32 and mismatch float64/float16 input.
    z = torch.ones_like(grad[:, :1])
    return F.normalize(torch.cat((grad, z), dim=1), dim=1)
class InputPadder:
    """Replicate-pad 4-D tensors so height and width become multiples of ``divis_by``.

    Args:
        dims: Any sequence whose last two entries are (height, width), e.g. a
            tensor's ``.shape``.
        mode: ``"sintel"`` splits the vertical and horizontal padding evenly
            around the image. Any other value splits the horizontal padding
            evenly but puts all vertical padding at the bottom.
        divis_by: Target divisor for both spatial dimensions (default 8).

    The padding is stored in ``self._pad`` as ``[left, right, top, bottom]``,
    which is the order ``F.pad`` expects for the last two dimensions.
    """

    def __init__(self, dims, mode="sintel", divis_by=8):
        self.ht, self.wd = dims[-2:]
        # Amount needed to reach the next multiple of divis_by (0 if already aligned).
        extra_h = (-self.ht) % divis_by
        extra_w = (-self.wd) % divis_by
        left = extra_w // 2
        right = extra_w - left
        if mode == "sintel":
            top = extra_h // 2
            bottom = extra_h - top
        else:
            top, bottom = 0, extra_h
        self._pad = [left, right, top, bottom]

    def pad(self, *inputs):
        """Return a list with each 4-D input replicate-padded by ``self._pad``."""
        for tensor in inputs:
            assert tensor.ndim == 4
        return [F.pad(tensor, self._pad, mode="replicate") for tensor in inputs]

    def unpad(self, x):
        """Crop the padding recorded in ``self._pad`` back off a 4-D tensor."""
        assert x.ndim == 4
        left, right, top, bottom = self._pad
        height, width = x.shape[-2:]
        return x[..., top : height - bottom, left : width - right]
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method="nearest", fill_value=0
)