Fine-tuning Florence-2
In the post Florence-2 we already explained what the Florence-2 model is and how to use it. In this post we are going to see how to fine-tune it.
This notebook has been automatically translated to make it accessible to more people. Please let me know if you see any typos.
Fine-tuning for Document VQA
This fine-tuning is based on the post by Merve Noyan, Andres Marafioti and Piotr Skalski, Fine-tuning Florence-2 - Microsoft's Cutting-edge Vision Language Models, in which they explain that although the model is very capable, it cannot answer questions about documents, so they re-train it on the DocumentVQA dataset.
Dataset
First we download the dataset. I leave the variable dataset_percentage in case you don't want to download all of it.
from datasets import load_dataset

dataset_percentage = 100

data_train = load_dataset("HuggingFaceM4/DocumentVQA", split=f"train[:{dataset_percentage}%]")
data_validation = load_dataset("HuggingFaceM4/DocumentVQA", split=f"validation[:{dataset_percentage}%]")
data_test = load_dataset("HuggingFaceM4/DocumentVQA", split=f"test[:{dataset_percentage}%]")

data_train, data_validation, data_test
(Dataset({
     features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],
     num_rows: 39463
 }),
 Dataset({
     features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],
     num_rows: 5349
 }),
 Dataset({
     features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],
     num_rows: 5188
 }))
We create a subset of the dataset in case you want to speed up the training; in my case I use 100% of the data.
percentage = 1

subset_data_train = data_train.select(range(int(len(data_train) * percentage)))
subset_data_validation = data_validation.select(range(int(len(data_validation) * percentage)))
subset_data_test = data_test.select(range(int(len(data_test) * percentage)))

print(f"train dataset length: {len(subset_data_train)}, validation dataset length: {len(subset_data_validation)}, test dataset length: {len(subset_data_test)}")
train dataset length: 39463, validation dataset length: 5349, test dataset length: 5188
We also instantiate the model
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoints = 'microsoft/Florence-2-base-ft'
model = AutoModelForCausalLM.from_pretrained(checkpoints, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(checkpoints, trust_remote_code=True)
As in the Florence-2 post, we create a function to ask the model for answers.
def create_prompt(task_prompt, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    return prompt

def generate_answer(task_prompt, text_input=None, image=None, device="cpu"):
    # Create prompt
    prompt = create_prompt(task_prompt, text_input)

    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get inputs
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    # Get outputs
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Decode the generated IDs
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Post-process the generated text
    parsed_answer = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))

    return parsed_answer
We test the model with 3 documents from the dataset, using the DocVQA task (first with the <DocVQA> tag and then without the angle brackets), to see if we get anything.
for idx in range(3):
    print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=data_train[idx]['image'], device=model.device))
    display(data_train[idx]['image'].resize([350, 350]))

{'<DocVQA>': 'docvQA'}
for idx in range(3):
    print(generate_answer(task_prompt="DocVQA", text_input='What do you see in this image?', image=data_train[idx]['image'], device=model.device))
    display(data_train[idx]['image'].resize([350, 350]))

{'DocVQA': 'unanswerable'}
We see that the answers are not good.

We now try with the OCR task.
for idx in range(3):
    print(generate_answer(task_prompt="<OCR>", image=data_train[idx]['image'], device=model.device))
    display(data_train[idx]['image'].resize([350, 350]))

{'<OCR>': 'ConfidentialDATE:11/8/18RJT FR APPROVALBUBJECT: Rl gdasPROPOSED RELEASE DATE:for responseFOR RELEASE TO!CONTRACT: P. CARTERROUTE TO!NameIntiifnPeggy CarterAce11/fesMura PayneDavid Fishhel037Tom Gisis Com-Diane BarrowsEd BlackmerTow KuckerReturn to Peggy Carter, PR, 16 Raynolds BuildingLLS. 2015Source: https://www.industrydocuments.ucsf.edu/docs/xnbl0037'}
We get the text of the documents, but not what the documents are about.
Finally, we test with the CAPTION tasks.
for idx in range(3):
    print(generate_answer(task_prompt="<CAPTION>", image=data_train[idx]['image'], device=model.device))
    print(generate_answer(task_prompt="<DETAILED_CAPTION>", image=data_train[idx]['image'], device=model.device))
    print(generate_answer(task_prompt="<MORE_DETAILED_CAPTION>", image=data_train[idx]['image'], device=model.device))
    display(data_train[idx]['image'].resize([350, 350]))

{'<CAPTION>': 'A certificate is stamped with the date of 18/18.'}
{'<DETAILED_CAPTION>': 'In this image we can see a paper with some text on it.'}
{'<MORE_DETAILED_CAPTION>': 'A letter is written in black ink on a white paper. The letters are written in a cursive language. The letter is addressed to peggy carter. '}
These answers are not enough for us either, so we are going to do the fine-tuning.
Fine-tuning
First we create a PyTorch dataset.
from torch.utils.data import Dataset

class DocVQADataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        question = "<DocVQA>" + example['question']
        first_answer = example['answers'][0]
        image = example['image']
        if image.mode != "RGB":
            image = image.convert("RGB")
        return question, first_answer, image

train_dataset = DocVQADataset(subset_data_train)
val_dataset = DocVQADataset(subset_data_validation)
Let's see it
train_dataset[0]

('<DocVQA>what is the date mentioned in this letter?',
 '1/8/93',
 <PIL.Image.Image image mode=RGB size=1695x2025>)
For comparison, this is the corresponding raw sample from the dataset:

data_train[0]
{'questionId': 337,
 'question': 'what is the date mentioned in this letter?',
 'question_types': ['handwritten', 'form'],
 'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=1695x2025>,
 'docId': 279,
 'ucsf_document_id': 'xnbl0037',
 'ucsf_document_page_no': '1',
 'answers': ['1/8/93']}
We create a dataloader
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW, AutoProcessor, get_scheduler

def collate_fn(batch):
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
    return inputs, answers

# Create DataLoader
batch_size = 8
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
Let's see a sample
sample = next(iter(train_loader))
sample
({'input_ids': tensor([[0, 41552, 42291, 846, 1864, 250, 15698, 12375, 16, 5, 3383, 9, 331, 9, 2042, 116, 2, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5, 11968, 196, 205, 22922, 346, 17487, 2, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5, 1229, 13, 403, 690, 116, 2, 1, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5, 5480, 1280, 116, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5, 1842, 346, 13, 20, 4680, 41828, 42237, 8, 30147, 17487, 2, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 560, 61, 675, 473, 42, 1013, 266, 9943, 7, 116, 2, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5, 1280, 9, 39432, 642, 6228, 2394, 2801, 11, 5, 576, 266, 17487, 2],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 1982, 11, 5, 6655, 2325, 23, 5, 299, 235, 9, 5, 3780, 116, 2]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        ...
 '97.00',
 '123',
 '1 January 1979 - 31 December 1979',
 '$2,720.14',
 'GPI'))
The raw sample contains a lot of information, so let's look at its length.
len(sample)
2
We get a length of 2 because the sample contains the model inputs and the answers.
sample_inputs = sample[0]
sample_answers = sample[1]
Let's look at the input.
sample_inputs
{'input_ids': tensor([[0, 41552, 42291, 846, 1864, 250, 15698, 12375, 16, 5, 3383, 9, 331, 9, 2042, 116, 2, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5, 11968, 196, 205, 22922, 346, 17487, 2, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5, 1229, 13, 403, 690, 116, 2, 1, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5, 5480, 1280, 116, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5, 1842, 346, 13, 20, 4680, 41828, 42237, 8, 30147, 17487, 2, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 560, 61, 675, 473, 42, 1013, 266, 9943, 7, 116, 2, 1, 1, 1, 1, 1],
        [0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5, 1280, 9, 39432, 642, 6228, 2394, 2801, 11, 5, 576, 266, 17487, 2],
        [0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 1982, 11, 5, 6655, 2325, 23, 5, 299, 235, 9, 5, 3780, 116, 2]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        ...
        [2.6400, 2.6400, 2.6400, ..., 1.3502, 0.7925, 1.3502],
        [2.6400, 2.6400, 2.6400, ..., 0.9319, 1.4025, 0.8448],
        [2.6400, 2.6400, 2.6400, ..., 1.0365, 1.2282, 0.8099]]]])}
The raw input also contains a lot of information, so let's take a look at its keys.
sample_inputs.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
As we can see, we have the input_ids and attention_mask, which correspond to the input text, and the pixel_values, which correspond to the image. Let's look at the dimensions of each one.
sample_inputs['input_ids'].shape, sample_inputs['attention_mask'].shape, sample_inputs['pixel_values'].shape
(torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 3, 768, 768]))
In input_ids and attention_mask each element has 23 tokens, and in pixel_values each element has 3 channels, 768 pixels high and 768 pixels wide.
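If you want to check what the collate function actually put into the batch, you can decode the token ids back to text. This is a minimal sketch, only using the processor and the sample_inputs variable defined above:

# Decode the batch of input_ids back to text to see the task tag plus each question
decoded_questions = processor.batch_decode(sample_inputs['input_ids'], skip_special_tokens=True)
for question in decoded_questions:
    print(question)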
Now let's see the answers
sample_answers
('JAMES A. RHODES',
 '1-800-992-3284',
 '$50,000',
 '97.00',
 '123',
 '1 January 1979 - 31 December 1979',
 '$2,720.14',
 'GPI')
We get 8 answers for the same reason as before: when we created the dataloader we set a batch size of 8.
len(sample_answers)
8
We create a function to do the fine-tuning.
def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    for epoch in range(epochs):
        # Training phase
        print(f"Training Epoch {epoch + 1}/{epochs}")
        model.train()
        train_loss = 0
        i = -1
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            i += 1
            inputs, answers = batch

            input_ids = inputs["input_ids"]
            pixel_values = inputs["pixel_values"]
            labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Average Training Loss: {avg_train_loss}")

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                inputs, answers = batch

                input_ids = inputs["input_ids"]
                pixel_values = inputs["pixel_values"]
                labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

                outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")
We train
train_model(train_loader, val_loader, model, processor, epochs=3, lr=1e-6)
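The post does not show it, but once training finishes you will probably want to keep the weights. A minimal sketch with the standard Hugging Face save methods (the output folder name is just an example I chose):

# Save the fine-tuned model and the processor to a local folder (example path)
output_dir = "florence-2-base-ft-docvqa"
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)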
Test the fine-tuned model
We now test the model on a few documents from the test set.
for idx in range(3):
    print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=data_test[idx]['image'], device=model.device))
    display(data_test[idx]['image'].resize([350, 350]))
We see that it now gives us information.
We now test again with the same documents from the training set as before, to compare the answers with what we got before the fine-tuning.
for idx in range(3):
    print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=data_train[idx]['image'], device=model.device))
    display(data_train[idx]['image'].resize([350, 350]))
It does not give very good results, but we have only trained for 3 epochs. It could be improved by training for longer, but what we can see is that before, when we used the <DocVQA> task tag, we did not get an answer, and now we do.
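As a possible next step (not part of the original notebook), you could ask the fine-tuned model the dataset's own questions and compare the output with the reference answers. A minimal sketch, assuming only the data_validation split and the generate_answer function defined above:

# Ask the model the dataset's own questions and compare with the reference answers
for idx in range(3):
    example = data_validation[idx]
    prediction = generate_answer(task_prompt="<DocVQA>", text_input=example['question'], image=example['image'], device=model.device)
    print(f"Question:   {example['question']}")
    print(f"Prediction: {prediction}")
    print(f"References: {example['answers']}")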