Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ecomfruitai/modeling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_generation_with_models(models, step):
with torch.no_grad():
from .predict import generate_image
test_prompt = "red apple, whole fruit, realistic photo"
generated = generate_image(test_prompt, models, num_inference_steps=20)
generate_image(test_prompt, models, num_inference_steps=20)
print("Generation test successful!")
except Exception as e:
print(f"Generation test failed: {e}")