[NLP] get_accuracy 함수 내부구현 이해

2025. 1. 6. 20:54·NLP(자연어처리)

 

모델이 예측한 값과 실제 값을 비교하여 정확도 계산

collator = DataCollatorWithPadding(tokenizer)
loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collator, shuffle=False)

 

  • DataLoader
    • 데이터를 배치(batch)로 나누어 모델에 입력
    • (데이터 셋이 너무 커서 한번에 처리 할 수 없을 때, 배치 단위로 모델에 전달) 
  • collate_fn
    • 데이터를 패딩 처리하는 역할
    • DataCollatorWithPadding을 사용하면 배치 내에서 길이가 다른 문장들이 있을 경우, 가장 긴 길이에 맞춰 나머지 문장들을 padding 해준다. 
with torch.inference_mode() :
    outputs = model(**sample)
    print("# outputs.logits.shape", outputs.logits.shape)

 

  • logits : 모델의 출력이자 예측된 확률 값 (batch_size , seq_len , vocab_size) 크기를 가진다. 
    • batch_size : 한번에 처리하는 샘플 수  
    • seq_len : 문장의 길이 (토큰의 개수)
    • vocab_size : 모델이 사용할 수 있는 단어의 수 (어휘의 크기) 
logits = outputs.logits[torch.arange(len(lengths) , device = lengths.device) , lenghts-1]

 

각 샘플의 마지막 토큰에 해당하는 로짓만 출력 ( lengths-1은 각 샘플의 마지막 토큰 위치) 

 

predictions = logits[...,label_ids]
prediction = predictions.argmax(-1)

 

라벨에 해당하는 로짓출력 및 가장 큰 값을 가진 라벨 선택 

 

trues += (labels == prediction).sum().item()
trues / len(dataset)

 

총 데이터 개수에서 정답을 맞춘 누적합을 나누면 정확도를 얻을 수 있다. 

 

#전체코드
def get_accuracy(model : AutoModelForCasualLM , dataset : Dataset , batch_size : int) :
    collator = DataCollatorWithPadding(tokenizer)
    loader = DataLoader(datset , batch_size = batch_size , collate_fn = collator , shuffle = False) 
    
    trues = 0
    
    for sample in tqdm(loader) :
        sample = sample.to('cuda')
        lengths = sample.pop('length')
        labels = sample.pop('labels')
        
        with torch.inference_mode() :
            outputs = model(**sample)
        logits = outputs.logits[torch.arange(len(lengths) , device = lengths.device) , length-1] 
        prediction = logits[...,label_ids]
        prediction = prediction.argmax(-1) 
        
        trues += (labels == prediction).sum().item() 
        
    return trues / len(datset) 
    
print(get_acuuracy(model , instructions , batch_size = 8))

'NLP(자연어처리)' 카테고리의 다른 글

Quantization & Prompt Engineering  (1) 2025.01.29
[NLP] PEFT(Parameter Efficient Tuning) : LoRA 코드  (0) 2025.01.06
[NLP] R.A.G 기법  (0) 2024.07.20
'NLP(자연어처리)' 카테고리의 다른 글
  • Quantization & Prompt Engineering
  • [NLP] PEFT(Parameter Efficient Tuning) : LoRA 코드
  • [NLP] R.A.G 기법
zzoming_00
zzoming_00
꾸준함🍀
  • zzoming_00
    ZZOMING'S TECH BLOG
    zzoming_00
  • 전체
    오늘
    어제
    • ALL (56)
      • Docker (0)
      • 경진대회 (2)
      • 사이드 프로젝트 (5)
      • NLP(자연어처리) (4)
      • CV(컴퓨터비전) (2)
      • ML&DL (9)
      • Git (4)
      • Python (10)
      • Algorithm (19)
  • 블로그 메뉴

    • 홈
    • 글쓰기
    • 태그
    • 방명록
  • 링크

    • 글쓰기
  • 공지사항

  • 인기 글

  • 태그

    챗봇
    참조방식
    알고리즘
    내장함수
    고차함수
    파이썬
    LLM
    클래스
    object
  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
zzoming_00
[NLP] get_accuracy 함수 내부구현 이해
상단으로

티스토리툴바