diff --git a/scripts/run_multimodal.py b/scripts/run_multimodal.py index 231e340..d040dac 100644 --- a/scripts/run_multimodal.py +++ b/scripts/run_multimodal.py @@ -99,7 +99,7 @@ def main(_): for key in image_paths: try: image[key] = Image.open(image_paths[key]) # Open local file - image[key].show() + # image[key].show() except IOError as e: print(f"Error loading image: {e}") exit() @@ -113,8 +113,7 @@ def main(_): device = torch.device(_DEVICE.value) with _set_default_tensor_type(model_config.get_dtype()): model = gemma3_model.Gemma3ForMultimodalLM(model_config) - model.load_state_dict(torch.load(_CKPT.value)['model_state_dict']) - # model.load_weights(_CKPT.value) + model.load_weights(_CKPT.value) model = model.to(device).eval() print('Model loading done')