Fine-tuning Florence-2
In the Florence-2 post we already explained the Florence-2 model and how to use it, so in this post we will see how to fine-tune it.
This notebook has been automatically translated to make it accessible to more people; please let me know if you spot any typos.
Fine-tuning for DocumentVQA
This fine-tuning is based on the post Fine-tuning Florence-2 - Microsoft's Cutting-edge Vision Language Models by Merve Noyan, Andres Marafioti and Piotr Skalski, in which they explain that, although the model is very capable, it does not support question answering on documents, so they retrain it on the DocumentVQA dataset.
Dataset
First we download the dataset. I left the dataset_percentage variable in case you do not want to download the whole thing.
from datasets import load_dataset

dataset_percentage = 100

data_train = load_dataset("HuggingFaceM4/DocumentVQA", split=f"train[:{dataset_percentage}%]")
data_validation = load_dataset("HuggingFaceM4/DocumentVQA", split=f"validation[:{dataset_percentage}%]")
data_test = load_dataset("HuggingFaceM4/DocumentVQA", split=f"test[:{dataset_percentage}%]")

data_train, data_validation, data_test
(Dataset({
     features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],
     num_rows: 39463
 }),
 Dataset({
     features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],
     num_rows: 5349
 }),
 Dataset({
     features: ['questionId', 'question', 'question_types', 'image', 'docId', 'ucsf_document_id', 'ucsf_document_page_no', 'answers'],
     num_rows: 5188
 }))
We create a subset of the dataset in case you want to speed up training; in my case I use 100% of the data.
percentage = 1

subset_data_train = data_train.select(range(int(len(data_train) * percentage)))
subset_data_validation = data_validation.select(range(int(len(data_validation) * percentage)))
subset_data_test = data_test.select(range(int(len(data_test) * percentage)))

print(f"train dataset length: {len(subset_data_train)}, validation dataset length: {len(subset_data_validation)}, test dataset length: {len(subset_data_test)}")
train dataset length: 39463, validation dataset length: 5349, test dataset length: 5188
We also instantiate the model.
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoints = 'microsoft/Florence-2-base-ft'

model = AutoModelForCausalLM.from_pretrained(checkpoints, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(checkpoints, trust_remote_code=True)
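As an optional quick check before fine-tuning, we can count how many trainable parameters the model has; a minimal sketch:

# Optional: count the trainable parameters of the loaded model
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable_params:,}")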
As in the Florence-2 post, we create helper functions to build the prompt and to ask the model for answers.
def create_prompt(task_prompt, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
    return prompt
def generate_answer(task_prompt, text_input=None, image=None, device="cpu"):
    # Create prompt
    prompt = create_prompt(task_prompt, text_input)

    # Ensure the image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Get inputs
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

    # Get outputs
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Decode the generated IDs
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Post-process the generated text
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )

    return parsed_answer
We test the model on 3 documents from the dataset using the DocVQA task, to see if we get anything useful.
for idx in range(3):
print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=data_train[idx]['image'], device=model.device))
display(data_train[idx]['image'].resize([350, 350]))
for idx in range(3):
print(generate_answer(task_prompt="DocVQA", text_input='What do you see in this image?', image=data_train[idx]['image'], device=model.device))
display(data_train[idx]['image'].resize([350, 350]))
We can see that the answers are not good.
Now we try the OCR task.
for idx in range(3):
print(generate_answer(task_prompt="<OCR>", image=data_train[idx]['image'], device=model.device))
display(data_train[idx]['image'].resize([350, 350]))
We get the text of the documents, but not what the documents are about.
Finally, we try the CAPTION tasks.
for idx in range(3):
print(generate_answer(task_prompt="<CAPTION>", image=data_train[idx]['image'], device=model.device))
print(generate_answer(task_prompt="<DETAILED_CAPTION>", image=data_train[idx]['image'], device=model.device))
print(generate_answer(task_prompt="<MORE_DETAILED_CAPTION>", image=data_train[idx]['image'], device=model.device))
display(data_train[idx]['image'].resize([350, 350]))
These answers do not work either, so let's move on to the fine-tuning.
Fine-tuning
First, we create a PyTorch dataset.
from torch.utils.data import Dataset

class DocVQADataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        question = "<DocVQA>" + example['question']
        first_answer = example['answers'][0]
        image = example['image']
        if image.mode != "RGB":
            image = image.convert("RGB")
        return question, first_answer, image
train_dataset = DocVQADataset(subset_data_train)
val_dataset = DocVQADataset(subset_data_validation)
Let's take a look at it.
train_dataset[0]
('<DocVQA>what is the date mentioned in this letter?',
 '1/8/93',
 <PIL.Image.Image image mode=RGB size=1695x2025>)
data_train[0]
{'questionId': 337,
 'question': 'what is the date mentioned in this letter?',
 'question_types': ['handwritten', 'form'],
 'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=1695x2025>,
 'docId': 279,
 'ucsf_document_id': 'xnbl0037',
 'ucsf_document_page_no': '1',
 'answers': ['1/8/93']}
We create a data loader.
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AdamW, AutoProcessor, get_scheduler)

def collate_fn(batch):
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
    return inputs, answers

# Create DataLoader
batch_size = 8
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
Let's take a look at one sample.
sample = next(iter(train_loader))
sample
({'input_ids': tensor([[ 0, 41552, 42291, 846, 1864, 250, 15698, 12375, 16, 5,3383, 9, 331, 9, 2042, 116, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,11968, 196, 205, 22922, 346, 17487, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,1229, 13, 403, 690, 116, 2, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,5480, 1280, 116, 2, 1, 1, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1842, 346, 13, 20, 4680, 41828, 42237, 8, 30147, 17487,2, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 560, 61, 675,473, 42, 1013, 266, 9943, 7, 116, 2, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1280, 9, 39432, 642, 6228, 2394, 2801, 11, 5, 576,266, 17487, 2],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 1982,11, 5, 6655, 2325, 23, 5, 299, 235, 9, 5,3780, 116, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],...'97.00','123','1 January 1979 - 31 December 1979','$2,720.14','GPI'))
The raw sample contains a lot of information, so let's look at its length.
len(sample)
2
We get a length of 2 because we have the inputs for the model and the answers.
sample_inputs = sample[0]
sample_answers = sample[1]
Let's look at the inputs.
sample_inputs
{'input_ids': tensor([[ 0, 41552, 42291, 846, 1864, 250, 15698, 12375, 16, 5,3383, 9, 331, 9, 2042, 116, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,11968, 196, 205, 22922, 346, 17487, 2, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,1229, 13, 403, 690, 116, 2, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 5,5480, 1280, 116, 2, 1, 1, 1, 1, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1842, 346, 13, 20, 4680, 41828, 42237, 8, 30147, 17487,2, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 560, 61, 675,473, 42, 1013, 266, 9943, 7, 116, 2, 1, 1,1, 1, 1],[ 0, 41552, 42291, 846, 1864, 250, 15698, 12196, 16, 5,1280, 9, 39432, 642, 6228, 2394, 2801, 11, 5, 576,266, 17487, 2],[ 0, 41552, 42291, 846, 1864, 250, 15698, 2264, 16, 1982,11, 5, 6655, 2325, 23, 5, 299, 235, 9, 5,3780, 116, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],...[ 2.6400, 2.6400, 2.6400, ..., 1.3502, 0.7925, 1.3502],[ 2.6400, 2.6400, 2.6400, ..., 0.9319, 1.4025, 0.8448],[ 2.6400, 2.6400, 2.6400, ..., 1.0365, 1.2282, 0.8099]]]])}
The raw inputs also contain too much information, so let's look at the keys.
sample_inputs.keys()
dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
As we can see, we have input_ids and attention_mask, which correspond to the input text, and pixel_values, which corresponds to the image. Let's look at the shape of each one.
sample_inputs['input_ids'].shape, sample_inputs['attention_mask'].shape, sample_inputs['pixel_values'].shape
(torch.Size([8, 23]), torch.Size([8, 23]), torch.Size([8, 3, 768, 768]))
In input_ids and attention_mask, each element has 23 tokens, and in pixel_values each element has 3 channels and is 768 pixels high and 768 pixels wide.
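If you want to check what text the model actually receives, a minimal sketch (for inspection only) is to decode the input_ids back into text:

# Decode the tokenized questions back into text to inspect the prompts
decoded_questions = processor.batch_decode(sample_inputs["input_ids"], skip_special_tokens=True)
for question in decoded_questions:
    print(question)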
Now let's look at the answers.
sample_answers
('JAMES A. RHODES',
 '1-800-992-3284',
 '$50,000',
 '97.00',
 '123',
 '1 January 1979 - 31 December 1979',
 '$2,720.14',
 'GPI')
We get 8 answers, for the same reason as before: when we created the data loader we set a batch size of 8.
len(sample_answers)
8
We create a function to do the fine-tuning.
def train_model(train_loader, val_loader, model, processor, epochs=10, lr=1e-6):
    optimizer = AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    for epoch in range(epochs):
        # Training phase
        print(f"Training Epoch {epoch + 1}/{epochs}")
        model.train()
        train_loss = 0
        i = -1
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
            i += 1
            inputs, answers = batch
            input_ids = inputs["input_ids"]
            pixel_values = inputs["pixel_values"]
            labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Average Training Loss: {avg_train_loss}")

        # Validation phase
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                inputs, answers = batch
                input_ids = inputs["input_ids"]
                pixel_values = inputs["pixel_values"]
                labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)

                outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Average Validation Loss: {avg_val_loss}")
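One detail to note: in the function above, the padding tokens in the labels also count toward the loss. A possible refinement, not applied in this notebook, is to mask them with -100 so the loss ignores them, assuming the model follows the standard Hugging Face convention for the ignore index. A minimal sketch, reusing the sample batch from before:

# Optional sketch: mask padding tokens in the labels so the loss ignores them.
# Assumes the model uses the standard -100 ignore index, as Hugging Face seq2seq models usually do.
masked_labels = processor.tokenizer(
    text=list(sample_answers), return_tensors="pt", padding=True, return_token_type_ids=False
).input_ids.to(device)
masked_labels[masked_labels == processor.tokenizer.pad_token_id] = -100

outputs = model(
    input_ids=sample_inputs["input_ids"],
    pixel_values=sample_inputs["pixel_values"],
    labels=masked_labels,
)
print(outputs.loss)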
We train the model.
train_model(train_loader, val_loader, model, processor, epochs=3, lr=1e-6)
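Once training is done, it can be useful to save the fine-tuned weights and the processor so they can be reloaded later. A minimal sketch; the output directory name is just an example:

# Save the fine-tuned model and processor (the directory name is arbitrary)
save_directory = "florence2-base-ft-docvqa"
model.save_pretrained(save_directory)
processor.save_pretrained(save_directory)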
Test the fine-tuned model
Now we test the model on some documents from the test set.
for idx in range(3):
print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=data_test[idx]['image'], device=model.device))
display(data_test[idx]['image'].resize([350, 350]))
We can see that it now gives us information.
Now let's try again the same training documents we used before the fine-tuning, to compare with what we got at the beginning.
for idx in range(3):
print(generate_answer(task_prompt="<DocVQA>", text_input='What do you see in this image?', image=data_train[idx]['image'], device=model.device))
display(data_train[idx]['image'].resize([350, 350]))
The results are not great, but we only trained for 3 epochs. They could be improved with more training, but what you can see is that, when we used the <DocVQA> task tag before, we got no answer, whereas now we do.
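As a rough sanity check, not part of the original notebook, you could compare the fine-tuned model's answers with the annotated ones on a few validation examples. A minimal sketch using exact, case-insensitive matching (the official DocVQA metric is ANLS, so this is only an approximation):

# Rough evaluation on a few validation examples (exact match, case-insensitive)
num_samples = 10
correct = 0
for idx in range(num_samples):
    example = data_validation[idx]
    parsed = generate_answer(task_prompt="<DocVQA>", text_input=example['question'], image=example['image'], device=model.device)
    prediction = parsed.get("<DocVQA>", "").strip().lower()
    gold_answers = [answer.strip().lower() for answer in example['answers']]
    correct += int(prediction in gold_answers)
print(f"Exact match on {num_samples} validation examples: {correct / num_samples:.2f}")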