diff --git a/gradio-app/app.py b/gradio-app/app.py
index d50843f..2889b4c 100644
--- a/gradio-app/app.py
+++ b/gradio-app/app.py
@@ -279,6 +279,17 @@ def show_global_error(message: str):
 def hide_global_error():
     return gr.update(value="", visible=False)
 
+
+def unload_model():
+    global g_model, g_processor, g_quant
+    print("Unloading model and clearing memory...")
+    g_model = None
+    g_processor = None
+    g_quant = None
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
 
 def load_model(quant: str, status: gr.HTML | None = None):
     """Load the model and processor if not already loaded."""
@@ -313,7 +324,7 @@ def load_model(quant: str, status: gr.HTML | None = None):
 
     try:
         if quant == "bf16":
-            g_model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype="bfloat16", device_map=0)
+            g_model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, dtype="bfloat16", device_map=0)
             assert isinstance(g_model, LlavaForConditionalGeneration), f"Expected LlavaForConditionalGeneration, got {type(g_model)}"
             if _HAS_LIGER:
                 try:
@@ -338,7 +349,7 @@ def load_model(quant: str, status: gr.HTML | None = None):
             else:
                 raise ValueError(f"Unknown quantization type: {quant}")
 
-            g_model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype="auto", device_map=0, quantization_config=qnt_config)
+            g_model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, dtype="auto", device_map=0, quantization_config=qnt_config)
             assert isinstance(g_model, LlavaForConditionalGeneration), f"Expected LlavaForConditionalGeneration, got {type(g_model)}"
 
         g_model.eval()
@@ -426,7 +437,7 @@ def print_system_info():
 
 
 @torch.no_grad()
-def chat_joycaption(input_image: Image.Image, prompt: str, temperature: float, top_p: float, max_new_tokens: int, quant: str) -> Generator[dict, None, None]:
+def chat_joycaption(input_image: Image.Image, prompt: str, temperature: float, top_p: float, max_new_tokens: int, quant: str, unload_after: bool) -> Generator[dict, None, None]:
     # Hide any previous global errors
     yield {global_error: hide_global_error()}
 
@@ -491,6 +502,11 @@ def chat_joycaption(input_image: Image.Image, prompt: str, temperature: float, t
        t.join()
 
        yield {single_status_output: gr.update(value="Captioning complete!")}
+
+        if unload_after:
+            unload_model()
+            yield {single_status_output: format_info("Model unloaded from memory.")}
+
    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        print(error_msg)
@@ -752,8 +768,10 @@ def collate_fn(batch: list[tuple[Path, Image.Image, str, str] | None], *, proces
                with gr.Column(scale=1):
                    initial_single_prompt = build_prompt(caption_type.value, caption_length.value, extra_options.value, name_input.value)
                    prompt_box_single = gr.Textbox(lines=4, label="Confirm or Edit Prompt", value=initial_single_prompt, interactive=True, elem_id="single_prompt_box")
+                    unload_after_gen = gr.Checkbox(label="Unload model from RAM/VRAM after generation", value=False)
                    run_button_single = gr.Button("Caption", variant="primary")
                    output_caption_single = gr.Textbox(label="Generated Caption", lines=8, interactive=True, elem_id="single_output_box")
+
 
            # Batch Processing Tab
            with gr.TabItem("Batch Processing", id="batch_tab"):
@@ -810,7 +828,7 @@ def collate_fn(batch: list[tuple[Path, Image.Image, str, str] | None], *, proces
    # Handle single image captioning
    run_button_single.click(
        chat_joycaption,
-        inputs=[input_image_single, prompt_box_single, temperature_slider, top_p_slider, max_tokens_slider, model_quantization],
+        inputs=[input_image_single, prompt_box_single, temperature_slider, top_p_slider, max_tokens_slider, model_quantization, unload_after_gen],
        outputs=[single_status_output, output_caption_single, global_error],
    )
 
@@ -824,4 +842,4 @@ def collate_fn(batch: list[tuple[Path, Image.Image, str, str] | None], *, proces
 
 if __name__ == "__main__":
    print_system_info()
-    demo.launch()
\ No newline at end of file
+    demo.launch()
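
Note: the torch_dtype -> dtype renames appear to track the deprecation of the torch_dtype argument to from_pretrained in recent Transformers releases. Below is a minimal standalone sketch of the unload pattern this patch adds, assuming the same module-level globals as app.py; the type comments and everything else are illustrative, not part of the patch:

    import gc

    import torch

    # Module-level caches, mirroring app.py (types are illustrative).
    g_model = None      # the loaded LlavaForConditionalGeneration, or None
    g_processor = None  # the matching processor, or None
    g_quant = None      # quantization tag of the loaded model, e.g. "bf16"

    def unload_model() -> None:
        """Drop all Python references to the model, then reclaim RAM and VRAM."""
        global g_model, g_processor, g_quant
        g_model = None
        g_processor = None
        g_quant = None
        gc.collect()                  # collect the now-unreferenced weights first
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached, unused CUDA allocations

The order matters: torch.cuda.empty_cache() only returns blocks the caching allocator no longer has live tensors in, so the references must be dropped and collected before the cache is emptied.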