From ae91b63fe9dcbeb0f1670ed6b1527018ba4446b2 Mon Sep 17 00:00:00 2001 From: "Carlos G." <78917769+carlosguzu@users.noreply.github.com> Date: Wed, 25 Jun 2025 13:38:59 -0500 Subject: [PATCH] Fix AMP import and remove unused variable --- ecomfruitai/modeling/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ecomfruitai/modeling/train.py b/ecomfruitai/modeling/train.py index 50e1c2c..9160714 100644 --- a/ecomfruitai/modeling/train.py +++ b/ecomfruitai/modeling/train.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F from torch.optim import AdamW -from torch.amp import autocast, GradScaler +from torch.cuda.amp import autocast, GradScaler from tqdm import tqdm from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL from transformers import CLIPTextModel, CLIPTokenizer @@ -146,7 +146,7 @@ def test_generation_with_models(models, step): with torch.no_grad(): from .predict import generate_image test_prompt = "red apple, whole fruit, realistic photo" - generated = generate_image(test_prompt, models, num_inference_steps=20) + generate_image(test_prompt, models, num_inference_steps=20) print("Generation test successful!") except Exception as e: print(f"Generation test failed: {e}")