diff --git a/diffusion_pen/diff_unet.py b/diffusion_pen/diff_unet.py index ae6a322..22ed896 100644 --- a/diffusion_pen/diff_unet.py +++ b/diffusion_pen/diff_unet.py @@ -203,7 +203,7 @@ def forward(self, x, context=None, mask=None): k = self.to_k(context) v = self.to_v(context) - mask = None #torch.ones(1, 8192).bool().cuda('cuda:6') + mask = None q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) diff --git a/main.py b/main.py index e116905..99e4a79 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import os +import shutil import torch from typing import List from model_inference import main_sample @@ -83,8 +84,17 @@ async def generate_handwriting(payload: HandwritingRequest): model_pipeline=pipeline, text_list=texts, style_refs=style_image_paths, - uname=uname + uname=uname, + device=torch.device("cuda" if torch.cuda.is_available() else "cpu") ) + # check and remove the tmp/uname folder + base_tmp = "/tmp" + user_tmp = os.path.join(base_tmp, uname) + + if os.path.isdir(user_tmp): + shutil.rmtree(user_tmp) + + os.makedirs(user_tmp, exist_ok=True) return { "message": "Handwriting generation completed.", diff --git a/model_inference.py b/model_inference.py index c6e68ad..4b8d63f 100644 --- a/model_inference.py +++ b/model_inference.py @@ -119,13 +119,9 @@ def main_sample( text_list : List[str], style_refs : List[str], uname : str, + device : str ): - device = 'cuda' - if not torch.cuda.is_available(): - print("CUDA is needed.") - return - # tokenizer and text encoder tokenizer = model_pipeline.text_tokenizer#CanineTokenizer.from_pretrained("google/canine-c") text_encoder = model_pipeline.text_encoder#CanineModel.from_pretrained("google/canine-c").to(device) @@ -175,7 +171,7 @@ def main_sample( device=device, img_h=64, img_w=256, - steps=100, + steps= 30, ) image = image_tensor[0].cpu() if image.size(0) == 4: # latent decoded returns 3 channels diff --git a/model_pipeline.py b/model_pipeline.py index a559c01..db2d4a9 100644 --- a/model_pipeline.py +++ b/model_pipeline.py @@ -35,6 +35,7 @@ def __init__(self, vocab_size=95, # unused internally; placeholder text_encoder=self.text_encoder, args=SimpleArgs(interpolation=False, mix_rate=None), + ) self.unet = UNetModel(**unet_cfg).to(device)