Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def from_pretrained(cls, model_uri: str, **kwargs):
from safetensors.torch import load_file
from .ae_nn import AutoEncoder

base = pathlib.Path(huggingface_hub.snapshot_download(model_uri))
try:
base = pathlib.Path(huggingface_hub.snapshot_download(model_uri))
except Exception:
base = pathlib.Path(model_uri)

enc_cfg = OmegaConf.load(base / "encoder_conf.yml").model
dec_cfg = OmegaConf.load(base / "decoder_conf.yml").model
Expand Down
4 changes: 3 additions & 1 deletion src/world_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
"""
model_uri: HF URI or local folder containing model.safetensors and config.yaml
quant: None | w8a8 | nvfp4
model_config_overrides: Dict to override model config values
"""
self.device, self.dtype = device, dtype

Expand All @@ -49,7 +50,8 @@ def __init__(

self.prompt_encoder = None
if self.model_cfg.prompt_conditioning is not None:
self.prompt_encoder = PromptEncoder("google/umt5-xl", dtype=dtype).to(device).eval() # TODO: dont hardcode
pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl")
self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval()

self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg).to(device=device, dtype=dtype).eval()
apply_inference_patches(self.model)
Expand Down