mirror of https://github.com/openai/shap-e.git (synced 2026-02-02 17:59:50 +08:00)

Commit: example of creating meshes
@@ -80,6 +80,21 @@
     "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
     "    display(gif_widget(images))"
    ]
-  }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "85a4dce4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Example of saving the latents as meshes.\n",
+    "from shap_e.util.notebooks import decode_latent_mesh\n",
+    "\n",
+    "for i, latent in enumerate(latents):\n",
+    "    with open(f'example_mesh_{i}.ply', 'wb') as f:\n",
+    "        decode_latent_mesh(xm, latent).tri_mesh().write_ply(f)"
+   ]
+  }
  ],
  "metadata": {
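Note: the new notebook cell writes one binary PLY file per sampled latent (example_mesh_0.ply, example_mesh_1.ply, ...). As a quick sanity check outside the notebook, the exported files can be opened with any standard mesh library; a minimal sketch using trimesh, which is an external package and not a shap-e dependency (assumed installed separately):

    # Sketch: inspect a mesh exported by the cell above. trimesh is an
    # assumed external dependency (pip install trimesh), not part of shap-e.
    import trimesh

    mesh = trimesh.load("example_mesh_0.ply")
    print(f"{len(mesh.vertices)} vertices, {len(mesh.faces)} faces")

The PLY files also open directly in viewers such as MeshLab or Blender.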
@@ -256,6 +256,16 @@ def render_views_from_stf(
     if output_srgb:
         tf_out.channels = _convert_srgb_to_linear(tf_out.channels)
 
+    # Make sure the raw meshes have colors.
+    with torch.autocast(device_type, enabled=False):
+        textures = tf_out.channels.float()
+        assert len(textures.shape) == 3 and textures.shape[-1] == len(
+            texture_channels
+        ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+        for m, texture in zip(raw_meshes, textures):
+            texture = texture[: len(m.verts)]
+            m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}
+
     args = dict(
         options=options,
         texture_channels=texture_channels,
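For readers skimming the new block: tf_out.channels holds the texture field's output for a batch of meshes (hence the [meta_batch x inner_batch x texture_channels] assert), and texture.unbind(-1) splits a [num_points x num_channels] slice into one 1-D tensor per channel, which is then zipped with the channel names to build vertex_channels. A toy illustration of the pattern, with made-up shapes and assumed RGB channel names:

    # Toy illustration of the vertex-coloring pattern above; the shapes and
    # channel names are illustrative stand-ins, not shap-e's real data.
    import torch

    texture_channels = ("R", "G", "B")                 # assumed channel names
    verts = torch.rand(100, 3)                         # stand-in for m.verts
    texture = torch.rand(128, len(texture_channels))   # field output may cover more points

    texture = texture[: len(verts)]                    # keep one row per vertex
    vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))
    assert vertex_channels["R"].shape == (100,)        # one scalar per vertex per channel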
@@ -315,6 +325,8 @@ def _render_with_pytorch3d(
     raw_meshes: List[TorchMesh],
     tf_out: AttrDict,
 ):
+    _ = tf_out
+
     # Lazy import because pytorch3d is installed lazily.
     from shap_e.rendering.pytorch3d_util import (
         blender_uniform_lights,
@@ -328,14 +340,6 @@ def _render_with_pytorch3d(
     device_type = device.type
 
-    with torch.autocast(device_type, enabled=False):
-        textures = tf_out.channels.float()
-        assert len(textures.shape) == 3 and textures.shape[-1] == len(
-            texture_channels
-        ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
-        for m, texture in zip(raw_meshes, textures):
-            texture = texture[: len(m.verts)]
-            m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))}
 
     meshes = convert_meshes(raw_meshes)
 
     lights = blender_uniform_lights(
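Taken together, the last three hunks are a hoist rather than a behavior change: the coloring block deleted here from _render_with_pytorch3d now runs in render_views_from_stf itself, so raw meshes get vertex colors on every STF render, not only when the pytorch3d backend is installed; tf_out becomes unused in the pytorch3d path, hence the added _ = tf_out. The torch.autocast(device_type, enabled=False) wrapper keeps the color tensors in full float32 even when the surrounding render runs under mixed precision; a self-contained illustration of that guard:

    # Standalone illustration of the autocast guard used in the moved block:
    # disabling autocast locally restores float32 inside a float16 region.
    import torch

    if torch.cuda.is_available():      # float16 autocast applies on CUDA
        x = torch.rand(4, 4, device="cuda")
        with torch.autocast("cuda"):
            y = x @ x                  # runs in float16 under autocast
            with torch.autocast("cuda", enabled=False):
                z = (x @ x).float()    # guard: computed and kept in float32
        print(y.dtype, z.dtype)        # torch.float16 torch.float32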
@@ -9,6 +9,7 @@ from PIL import Image
 
 from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
 from shap_e.models.transmitter.base import Transmitter, VectorDecoder
+from shap_e.rendering.torch_mesh import TorchMesh
 from shap_e.util.collections import AttrDict
 
 
@@ -60,6 +61,21 @@ def decode_latent_images(
     return [Image.fromarray(x) for x in arr]
 
 
+@torch.no_grad()
+def decode_latent_mesh(
+    xm: Union[Transmitter, VectorDecoder],
+    latent: torch.Tensor,
+) -> TorchMesh:
+    decoded = xm.renderer.render_views(
+        AttrDict(cameras=create_pan_cameras(2, latent.device)),  # lowest resolution possible
+        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
+            latent[None]
+        ),
+        options=AttrDict(rendering_mode="stf", render_with_direction=False),
+    )
+    return decoded.raw_meshes[0]
+
+
 def gif_widget(images):
     writer = io.BytesIO()
     images[0].save(
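The size-2 pan cameras here are just the cheapest way to satisfy render_views' camera argument; the mesh comes from the STF field itself, so the tiny render resolution should not affect the extracted geometry. A sketch of end-to-end use of the new helper, where latents is assumed to come from the sampling cells of the example notebooks:

    # Sketch: export sampled latents as PLY meshes. `latents` is assumed to
    # come from sample_latents(...) as in the shap-e example notebooks.
    import torch
    from shap_e.models.download import load_model
    from shap_e.util.notebooks import decode_latent_mesh

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xm = load_model("transmitter", device=device)

    for i, latent in enumerate(latents):
        tri = decode_latent_mesh(xm, latent).tri_mesh()
        with open(f"example_mesh_{i}.ply", "wb") as f:
            tri.write_ply(f)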