Fine tuning Florence-2
En el post Florence-2 ya explicamos el modelo Florence-2 y vimos cómo usarlo. Así que en este post vamos a ver cómo hacerle fine tuning.
Fine tuning para Document VQA
Este fine tuning está basado en el post de Merve Noyan, Andres Marafioti y Piotr Skalski, Fine-tuning Florence-2 - Microsoft's Cutting-edge Vision Language Models, en el que explican que, aunque este modelo es muy completo, no permite hacer preguntas sobre documentos, así que lo reentrenan con el dataset DocumentVQA
Dataset
En primer lugar descargamos el dataset. Dejo la variable dataset_percentage
por si no quieres descargarlo todo.
from datasets import load_dataset

# Percentage of every split to download; lower it for a quicker experiment.
dataset_percentage = 100

data_train = load_dataset("HuggingFaceM4/DocumentVQA", split=f"train[:{dataset_percentage}%]")
data_validation = load_dataset("HuggingFaceM4/DocumentVQA", split=f"validation[:{dataset_percentage}%]")
data_test = load_dataset("HuggingFaceM4/DocumentVQA", split=f"test[:{dataset_percentage}%]")

# Notebook-style display of the three loaded splits.
data_train, data_validation, data_test
(Dataset({features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],num_rows: 39463}),Dataset({features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],num_rows: 5349}),Dataset({features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],num_rows: 5188}))
Hacemos un subset del dataset por si quieres hacer el entrenamiento más rápido, en mi caso uso el 100% de los datos
# Fraction (0-1) of each already-downloaded split that is kept.
percentage = 1

subset_data_train = data_train.select(range(int(len(data_train) * percentage)))
subset_data_validation = data_validation.select(range(int(len(data_validation) * percentage)))
subset_data_test = data_test.select(range(int(len(data_test) * percentage)))

print(f"train dataset length: {len(subset_data_train)}, validation dataset length: {len(subset_data_validation)}, test dataset length: {len(subset_data_test)}")
train dataset length: 39463, validation dataset length: 5349, test dataset length: 5188
Instanciamos también el modelo
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The same checkpoint name is used for both the model and its processor.
checkpoints = 'microsoft/Florence-2-base-ft'
model = AutoModelForCausalLM.from_pretrained(checkpoints, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(checkpoints, trust_remote_code=True)
Al igual que en el post Florence-2 creamos una función para pedirle respuestas al modelo
def create_prompt(task_prompt, text_input=None):
    """Build the text prompt that is sent to the model.

    Args:
        task_prompt: Task tag, e.g. "<DocVQA>" or "<OCR>".
        text_input: Optional extra text appended right after the task tag.

    Returns:
        ``task_prompt`` alone when ``text_input`` is None, otherwise the two
        strings concatenated with no separator.
    """
    if text_input is None:
        return task_prompt
    return task_prompt + text_input
def generate_answer(task_prompt, text_input=None, image=None, device="cpu"):
    """Run one task on an image and return the post-processed answer.

    Args:
        task_prompt: Task tag passed to the model, e.g. "<DocVQA>" or "<OCR>".
        text_input: Optional text appended to the task tag.
        image: Image to query. It must expose ``mode``, ``convert``, ``width``
            and ``height``; it is converted to RGB when in another mode.
        device: Device the processor outputs are moved to before generation.

    Returns:
        Whatever ``processor.post_process_generation`` returns for the
        generated text and ``task_prompt``.

    Raises:
        ValueError: If ``image`` is None.
    """
    # Fail early with a clear message instead of an AttributeError on
    # ``image.mode`` further down.
    if image is None:
        raise ValueError("generate_answer requires an image")

    # Create prompt
    prompt = create_prompt(task_prompt, text_input)

    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get inputs
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    # Get outputs
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Decode the generated IDs; special tokens are kept for the
    # post-processing step below.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Post-process the generated text
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    return parsed_answer
Probamos el modelo con 3 documentos del dataset, con la tarea DocVQA
a ver si obtenemos algo
# Ask a generic question with the <DocVQA> task on the first three training documents.
for idx in range(3):
    image = data_train[idx]['image']  # fetch the sample once and reuse it
    print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=image, device=model.device))
    display(image.resize([350, 350]))
# Same query, but with the task name written without the angle brackets.
for idx in range(3):
    image = data_train[idx]['image']  # fetch the sample once and reuse it
    print(generate_answer(task_prompt="DocVQA", text_input='What do you see in this image?', image=image, device=model.device))
    display(image.resize([350, 350]))
Vemos que las respuestas no son buenas
Probamos ahora con la tarea OCR
# Run the <OCR> task on the first three training documents.
for idx in range(3):
    image = data_train[idx]['image']  # fetch the sample once and reuse it
    print(generate_answer(task_prompt="<OCR>", image=image, device=model.device))
    display(image.resize([350, 350]))
Obtenemos el texto de los documentos, pero no de qué tratan los documentos
Por último probamos con las tareas CAPTION
# Run the three caption tasks on the first three training documents.
for idx in range(3):
    image = data_train[idx]['image']  # fetch the sample once and reuse it
    print(generate_answer(task_prompt="<CAPTION>", image=image, device=model.device))
    print(generate_answer(task_prompt="<DETAILED_CAPTION>", image=image, device=model.device))
    print(generate_answer(task_prompt="<MORE_DETAILED_CAPTION>", image=image, device=model.device))
    display(image.resize([350, 350]))
Tampoco nos valen estas respuestas, así que vamos a hacer el fine tuning
Fine tuning
Primero creamos un dataset de Pytorch
from torch.utils.data import Dataset


class DocVQADataset(Dataset):
    """Wrap a DocumentVQA split as ``(question, answer, image)`` samples."""

    def __init__(self, data):
        """Store the underlying split; it is indexed lazily in ``__getitem__``."""
        self.data = data

    def __len__(self):
        """Return the number of examples in the wrapped split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the prompt, the first reference answer and the RGB image.

        The question is prefixed with the "<DocVQA>" task tag. Only the first
        entry of ``answers`` is used as the target.
        """
        example = self.data[idx]
        question = "<DocVQA>" + example['question']
        first_answer = example['answers'][0]
        image = example['image']
        # Convert any non-RGB image (e.g. mode "L") to RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        return question, first_answer, image
# Wrap the train and validation subsets for use with a DataLoader.
train_dataset = DocVQADataset(subset_data_train)
val_dataset = DocVQADataset(subset_data_validation)
Vamos a verlo
# Notebook-style display of the first (question, answer, image) sample.
train_dataset[0]
{'<DocVQA>': 'docvQA'}{'DocVQA': 'unanswerable'}{'<OCR>': 'ConfidentialDATE:11/8/18RJT FR APPROVALBUBJECT: Rl gdasPROPOSED RELEASE DATE:for responseFOR RELEASE TO!CONTRACT: P. CARTERROUTE TO!NameIntiifnPeggy CarterAce11/fesMura PayneDavid Fishhel037Tom Gisis Com-Diane BarrowsEd BlackmerTow KuckerReturn to Peggy Carter, PR, 16 Raynolds BuildingLLS. 2015Source: https://www.industrydocuments.ucsf.edu/docs/xnbl0037'}{'<CAPTION>': 'A certificate is stamped with the date of 18/18.'}{'<DETAILED_CAPTION>': 'In this image we can see a paper with some text on it.'}{'<MORE_DETAILED_CAPTION>': 'A letter is written in black ink on a white paper. The letters are written in a cursive language. The letter is addressed to peggy carter. '}('<DocVQA>what is the date mentioned in this letter?','1/8/93',<PIL.Image.Image image mode=RGB size=1695x2025>)
data_train[0]
{'questionId': 337,'question': 'what is the date mentioned in this letter?','question_types': ['handwritten', 'form'],'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=1695x2025>,'docId': 279,'ucsf_document_id': 'xnbl0037','ucsf_document_page_no': '1','answers': ['1/8/93']}
Creamos un dataloader
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AdamW, AutoProcessor, get_scheduler)


def collate_fn(batch):
    """Turn a list of ``(question, answer, image)`` samples into model inputs.

    Questions and images go through ``processor`` together, with padding,
    and the result is moved to ``device``. Answers are returned untouched.

    Returns:
        A ``(inputs, answers)`` pair, where ``answers`` is a tuple of strings.
    """
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
    return inputs, answers


# Create DataLoader
batch_size = 8
# NOTE(review): collate_fn moves tensors to `device`; confirm this still works
# before raising num_workers above 0.
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
Vamos a ver una muestra
# Pull a single batch from the training loader to inspect it.
sample = next(iter(train_loader))
# Notebook-style display of the raw batch.
sample
({'input_ids': tensor([[ 0, 41552, 42291, 846, 1864, 250, 15698, 12375, 16, 5,3383, 9, 331, 9, 2042, 116, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,11968, 196, 205, 22922, 346, 17487, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,1229, 13, 403, 690, 116, 2, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,5480, 1280, 116, 2, 1, 1, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1842, 346, 13, 20, 4680, 41828, 42237, 8, 30147, 17487,2, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 560, 61, 675,473, 42, 1013, 266, 9943, 7, 116, 2, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1280, 9, 39432, 642, 6228, 2394, 2801, 11, 5, 576,266, 17487, 2],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 1982,11, 5, 6655, 2325, 23, 5, 299, 235, 9, 5,3780, 116, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],...'97.00','123','1 January 1979 - 31 December 1979','$2,720.14','GPI'))
La muestra en crudo es mucha información, así que vamos a ver la longitud de la muestra
len(sample)
2
Obtenemos una longitud de 2 porque tenemos la entrada al modelo y la respuesta
# A batch is an (inputs, answers) pair, as returned by collate_fn.
sample_inputs = sample[0]
sample_answers = sample[1]
Vemos la entrada
# Notebook-style display of the model inputs of the batch.
sample_inputs
{'input_ids': tensor([[ 0, 41552, 42291, 846, 1864, 250, 15698, 12375, 16, 5,3383, 9, 331, 9, 2042, 116, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,11968, 196, 205, 22922, 346, 17487, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,1229, 13, 403, 690, 116, 2, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,5480, 1280, 116, 2, 1, 1, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1842, 346, 13, 20, 4680, 41828, 42237, 8, 30147, 17487,2, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 560, 61, 675,473, 42, 1013, 266, 9943, 7, 116, 2, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1280, 9, 39432, 642, 6228, 2394, 2801, 11, 5, 576,266, 17487, 2],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 1982,11, 5, 6655, 2325, 23, 5, 299, 235, 9, 5,3780, 116, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],...[ 2.6400, 2.6400, 2.6400, ..., 1.3502, 0.7925, 1.3502],[ 2.6400, 2.6400, 2.6400, ..., 0.9319, 1.4025, 0.8448],[ 2.6400, 2.6400, 2.6400, ..., 1.0365, 1.2282, 0.8099]]]])}
La entrada en crudo también tiene demasiada información, así que vamos a ver las keys
sample_inputs.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
Como vemos tenemos los input_ids
y los attention_mask
que corresponden al texto de entrada y los pixel_values
que corresponden a la imagen. Vamos a ver la dimensión de cada uno
sample_inputs['input_ids'].shape, sample_inputs['attention_mask'].shape, sample_inputs['pixel_values'].shape
(torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 3, 768, 768]))
En todos hay 8 elementos, porque al crear el dataloader pusimos un batch size de 8. En los input_ids
y attention_mask
cada elemento tiene 23 tokens (la longitud de la pregunta más larga del batch; las demás se rellenan con padding) y en los pixel_values
cada elemento tiene 3 canales, 768 píxeles de alto y 768 píxeles de ancho
Vamos ahora a ver las respuestas
sample_answers
('JAMES A. RHODES','1-800-992-3284','$50,000','97.00','123','1 January 1979 - 31 December 1979','$2,720.14','GPI')
Hemos obtenido 8 respuestas, por lo mismo que antes, porque al crear el dataloader pusimos un batch size de 8
len(sample_answers)
8
Creamos una función para hacer el fine tuning
def _batch_loss(model, processor, batch):
    """Return the model loss for one ``(inputs, answers)`` batch."""
    inputs, answers = batch
    input_ids = inputs["input_ids"]
    pixel_values = inputs["pixel_values"]
    # Put the labels on the same device as the inputs instead of relying on a
    # module-level ``device`` variable.
    labels = processor.tokenizer(
        text=answers,
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
    ).input_ids.to(input_ids.device)
    # Answers in a batch have different lengths, so the tokenizer pads them.
    # Replace the padding positions with -100 (the default ignore index of the
    # cross-entropy loss) so padding does not contribute to the loss.
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        labels = labels.masked_fill(labels == pad_token_id, -100)
    outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
    return outputs.loss


def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    """Fine-tune ``model`` and print the average train/validation loss per epoch.

    Args:
        train_loader: Iterable of ``(inputs, answers)`` training batches.
        val_loader: Iterable of ``(inputs, answers)`` validation batches.
        model: Model to train; it is updated in place.
        processor: Processor whose tokenizer encodes the answers into labels.
        epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate, decayed linearly over all training steps.
    """
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    for epoch in range(epochs):
        # Training phase
        print(f" Training Epoch {epoch + 1}/{epochs}")
        model.train()
        train_loss = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            loss = _batch_loss(model, processor, batch)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Average Training Loss: {avg_train_loss}")

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                val_loss += _batch_loss(model, processor, batch).item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")
Entrenamos
train_model(train_loader, val_loader, model, processor, epochs=3, lr=1e-6)
Probar el modelo fine tuned
Probamos ahora el modelo en unos cuantos documentos del conjunto de test
# Query the fine-tuned model on the first three test documents.
for idx in range(3):
    image = data_test[idx]['image']  # fetch the sample once and reuse it
    print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=image, device=model.device))
    display(image.resize([350, 350]))
Vemos que nos da información
Vamos ahora a volver a probar sobre los mismos documentos del conjunto de entrenamiento que usamos al principio, para comparar con lo que salía antes de entrenar
# Query the fine-tuned model on the same three training documents used before training.
for idx in range(3):
    image = data_train[idx]['image']  # fetch the sample once and reuse it
    print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=image, device=model.device))
    display(image.resize([350, 350]))
No da muy buenos resultados, pero solo hemos entrenado 3 épocas. Aunque se podría mejorar entrenando más, lo que se puede ver es que cuando antes usábamos el tag de tarea <DocVQA>
no obteníamos respuesta, pero ahora sí.