
Text Generation
temperature
Pipeline
from transformers import pipeline
pipe = pipeline(model='.', device=3, task='text-generation')
text = '세계 최초로 달에 착륙한 우주인의 이름은 '
pipe(text, do_sample=True, top_p=0.9, max_new_tokens=256)
pipeline을 쓰지 않으면 아래처럼 hf model의 generate()
를 사용하던지, 직접 generate를 구현(FastChat 처럼)해야 한다.
Tokenizer, Model 직접 호출
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained('.', use_fast=False, add_bos_token=True)
model = AutoModelForCausalLM.from_pretrained('.').to('cuda')
prompt = '<human>: 안녕? <bot>:'
input_ids = tokenizer(prompt, return_tensors='pt').to('cuda').input_ids
outputs = model.generate(input_ids,
do_sample=True,
max_new_tokens=256,
temperature=0.7,
top_p=0.95,
num_return_sequences=3)
# It outputs a variety(three) of results.
for i, output in enumerate(outputs):
print(f'{i}: {tokenizer.decode(output, skip_special_tokens=False)}\n')
max_length
: input prompt +max_new_tokens
num_return_sequences
: 지정한 수 만큼 독립적으로 계산 결과 리턴
입력, 출력
입력:
{
'input_ids': tensor([[50257, 1815, 539, ..., 50258, 50258, 50258]], device='cuda:0'),
'labels': tensor([[-100, -100, -100, ..., -100, -100, -100]], device='cuda:0'),
'attention_mask': tensor([[ True, True, True, ..., False, False, False]], device='cuda:0')
}
labels를 1칸 밀어내지 않고 같은 위치에 둔다. 답변이 아닌 경우에만 -100으로 마스킹한다. 50258 패딩 토큰(50257은 bos 토큰)도 동일하게 -100으로 마스킹 처리하여 loss 계산을 하지 않는다.
logits 값은 outputs = model(**inputs)
으로 간단하게 계산
출력:
alist = []
for i in range(200):
alist.append(inputs.input_ids[0][i])
print(i)
print(tokenizer.decode(alist))
print('NEXT:', tokenizer.decode(torch.topk(outputs.logits[0][i], 1).indices))
print()
logits는 다음 토큰에 대한 값을 출력하며 여기서는 간단하게 topk로 최댓값을 추출했다.
81
<||bos||> 호기심 많은 인간 (human)과 인공지능 봇 (AI bot)의 대화입니다. **** 봇은 인간의 질문에 대해 친절하게 유용하고 상세한 답변을 제공합니다. <human>: 엉겅퀴 제거하는 방법을 자세히 단계별로 설명해줘봐 <bot>:
NEXT: 엉
...
84
<||bos||> 호기심 많은 인간 (human)과 인공지능 봇 (AI bot)의 대화입니다. **** 봇은 인간의 질문에 대해 친절하게 유용하고 상세한 답변을 제공합니다. <human>: 엉겅퀴 제거하는 방법을 자세히 단계별로 설명해줘봐 <bot>: 엉겅
NEXT: 퀴
outputs은 가장 중요한 loss, logits, 이전 어텐션 계산 결과를 캐시하는 past_key_values 등으로 구성된다.
Last Modified: 2023/08/21 17:15:57