SFT Trainer
In [1]:
!pip install -q "torch==2.1.2" tensorboard wandb
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
In [1]:
# Install PyTorch & other libraries
!pip install -q "torch==2.1.2" tensorboard
# Install Hugging Face libraries
!pip install -q --upgrade \
"transformers==4.36.2" \
"datasets==2.16.1" \
"accelerate==0.26.1" \
"evaluate==0.4.1" \
"bitsandbytes==0.42.0" \
"trl==0.7.10" \
"peft==0.7.1"
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
In [2]:
!pip install flash-attn
Collecting flash-attn Downloading flash_attn-2.5.0.tar.gz (2.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 12.4 MB/s eta 0:00:0000:0100:01 Preparing metadata (setup.py) ... done Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from flash-attn) (2.1.2) Collecting einops Downloading einops-0.7.0-py3-none-any.whl (44 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 2.2 MB/s eta 0:00:00 Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from flash-attn) (23.0) Collecting ninja Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 307.2/307.2 kB 31.4 MB/s eta 0:00:00 Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (12.1.3.1) Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (11.0.2.54) Requirement already satisfied: triton==2.1.0 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2.1.0) Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105) Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2.18.1) Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105) Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (8.9.2.26) Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1.2) Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105) Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.1) Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (2023.10.0) Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (10.3.2.106) Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (12.1.0.106) Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (4.5.0) Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (12.1.105) Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (11.4.5.107) Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (1.12) Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->flash-attn) (3.9.0) Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->flash-attn) (12.3.101) Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->flash-attn) (2.1.1) Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->flash-attn) (1.3.0) Building wheels for 
collected packages: flash-attn Building wheel for flash-attn (setup.py) ... done Created wheel for flash-attn: filename=flash_attn-2.5.0-cp310-cp310-linux_x86_64.whl size=120823033 sha256=3335e74258645eb190597754d42c2fee391fbdeb772847f9e1de12da60450a33 Stored in directory: /root/.cache/pip/wheels/9e/c3/22/a576eb5627fb2c30dc4679a33d67d34d922d6dbeb24a9119b2 Successfully built flash-attn Installing collected packages: ninja, einops, flash-attn Successfully installed einops-0.7.0 flash-attn-2.5.0 ninja-1.11.1.1 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
In [4]:
!git config --global credential.helper store
In [5]:
from huggingface_hub import login
login(
token="", # ADD YOUR TOKEN HERE
add_to_git_credential=True
)
Token is valid (permission: write). Your token has been saved in your configured git credential helpers (store). Your token has been saved to /root/.cache/huggingface/token Login successful
In [1]:
from datasets import load_dataset

def create_conversation(sample):
    return {
        "messages": [
            {"role": "system", "content": sample["task"]},
            {"role": "user", "content": sample["query"]},
            {"role": "assistant", "content": sample["pos"]}
        ]
    }

# Load dataset from the hub
dataset = load_dataset("TokenBender/sentence_retrieval_hindi_SFT", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)

# Split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

print(dataset["train"][345]["messages"])

# Save datasets to disk
dataset["train"].to_json("train_dataset.json", orient="records")
dataset["test"].to_json("test_dataset.json", orient="records")
Map: 0%| | 0/12500 [00:00<?, ? examples/s]
[{'content': 'दक्षिण पूर्व एशिया में यात्रा के अनुभवों के बारे में ब्लॉग पोस्ट खोजें।', 'role': 'system'}, {'content': 'मैंने हाल ही में दक्षिण पूर्व एशिया की यात्रा की और बैंकॉक में जीवंत स्ट्रीट फूड दृश्य की खोज करने, हनोई के समृद्ध इतिहास में खुद को डुबोने और बाली में छिपे हुए रत्नों की खोज करने में एक अद्भुत समय बिताया। अंदरूनी सुझावों और सिफारिशों के लिए मेरे ब्लॉग पोस्ट को देखें!', 'role': 'user'}, {'content': 'अद्वितीय आराम और लुभावने दृश्यों की पेशकश करते हुए दक्षिण पूर्व एशिया में शीर्ष 10 लक्जरी रिसॉर्ट्स की खोज करें। इन विशिष्ट गंतव्यों में विश्व स्तरीय सुविधाओं और लाड़-प्यार की सेवाओं में शामिल हों, जो इस क्षेत्र में एक आरामदायक सैर के लिए एकदम सही हैं।', 'role': 'assistant'}]
Creating json from Arrow format: 0%| | 0/10 [00:00<?, ?ba/s]
Creating json from Arrow format: 0%| | 0/3 [00:00<?, ?ba/s]
Out[1]:
12038569
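Each record now follows the OpenAI messages format with one system, one user and one assistant turn. A quick sanity check I added (not part of the original run) to confirm no sample lost a turn during the conversion:

# Hedged sanity check: every converted sample should carry the three expected roles in order.
expected_roles = ["system", "user", "assistant"]
assert all(
    [m["role"] for m in sample["messages"]] == expected_roles
    for sample in dataset["train"]
)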
In [2]:
from datasets import load_dataset
# Load jsonl data from disk
dataset = load_dataset("json", data_files="train_dataset.json", split="train")
Generating train split: 0 examples [00:00, ? examples/s]
In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import setup_chat_format
import wandb
wandb.login()
%env WANDB_PROJECT=hindi_sft_test_tinyllama
# Hugging Face model id
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # or `mistralai/Mistral-7B-v0.1`
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'right' # to prevent warnings
# set chat template to OAI chatML, remove if you start from a fine-tuned model
model, tokenizer = setup_chat_format(model, tokenizer)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving. wandb: Currently logged in as: ahm-rimer. Use `wandb login --relogin` to force relogin
env: WANDB_PROJECT=hindi_sft_test_tinyllama
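setup_chat_format switches the tokenizer to the ChatML template, adds its special tokens, and resizes the model's token embeddings accordingly. A small sanity check I added (not part of the original run) to confirm both sides agree, since a mismatch here is what breaks inference later:

# The tokenizer should now use the ChatML end-of-turn token, and the embedding
# matrix should be at least as large as the enlarged vocabulary.
print(tokenizer.eos_token)  # expected: <|im_end|>
print(len(tokenizer), model.get_input_embeddings().weight.shape[0])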
In [4]:
from peft import LoraConfig
# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=32,
bias="none",
target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj'],
task_type="CAUSAL_LM",
)
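The target_modules list has to match the projection layer names the architecture actually uses. A throwaway check I added to cross-check it against the loaded model (the names in the comment are what Llama-style models expose):

# Collect every module name ending in "proj" so the LoRA targets can be verified.
proj_names = sorted({name.split(".")[-1] for name, _ in model.named_modules() if name.endswith("proj")})
print(proj_names)  # expected: ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']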
In [7]:
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="tinyllama_hindi_sentence_retrieval_sft", # directory to save and repository id
num_train_epochs=1, # number of training epochs
per_device_train_batch_size=4, # batch size per device during training
gradient_accumulation_steps=2, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory
optim="adamw_torch_fused", # use fused adamw optimizer
logging_steps=1, # log every step
save_steps=0.3, # save a checkpoint every 30% of the total training steps
learning_rate=2e-4, # learning rate, based on QLoRA paper
bf16=True, # use bfloat16 precision
tf32=True, # use tf32 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to hub
report_to="wandb", # report metrics to wandb
)
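With these arguments the effective batch size is 4 × 2 = 8 packed sequences per optimizer step. A rough, illustrative estimate of the resulting step count (num_packed_sequences is a placeholder for whatever packing the 10,000 chat samples into 2,048-token blocks yields; the run below ends up at 622 steps):

# Back-of-the-envelope step count; num_packed_sequences is hypothetical, not measured here.
effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps  # 4 * 2 = 8
num_packed_sequences = 4976  # placeholder value
print(num_packed_sequences // effective_batch_size)  # roughly 622 optimizer steps for one epoch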
In [8]:
from trl import SFTTrainer
max_seq_length = 2048 # max sequence length for model and packing of the dataset
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset,
peft_config=peft_config,
max_seq_length=max_seq_length,
tokenizer=tokenizer,
packing=True,
dataset_kwargs={
"add_special_tokens": False, # We template with special tokens
"append_concat_token": False, # No need to add additional separator token
}
)
Generating train split: 0 examples [00:00, ? examples/s]
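Because the dataset is in the messages format, the trainer should render each conversation with the tokenizer's ChatML template before tokenizing and packing. To see roughly what one example looks like as text, you can apply the template yourself (a hedged preview I added; the trainer does its own formatting internally):

# Render one sample with the ChatML template to inspect the <|im_start|>/<|im_end|> layout.
sample_text = tokenizer.apply_chat_template(dataset[0]["messages"], tokenize=False)
print(sample_text[:500])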
In [9]:
# start training, the model will be automatically saved to the hub and the output directory
trainer.train()
# save model
trainer.save_model()
Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to the W&B docs.
Tracking run with wandb version 0.16.2
Run data is saved locally in
/workspace/wandb/run-20240128_191736-13n2gxgh
View project at https://wandb.ai/ahm-rimer/hindi_sft_test_tinyllama
You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding. `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`... /opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants. warnings.warn( The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
[622/622 41:28, Epoch 1/1]
Step | Training Loss |
---|---|
1 | 1.197200 |
2 | 1.141000 |
3 | 1.131400 |
4 | 1.086400 |
5 | 1.089900 |
6 | 1.004200 |
7 | 1.032800 |
8 | 1.062700 |
9 | 1.045000 |
10 | 0.994600 |
11 | 0.979000 |
12 | 0.966600 |
13 | 0.980000 |
14 | 0.914500 |
15 | 0.952300 |
16 | 0.915400 |
17 | 0.941800 |
18 | 0.949200 |
19 | 0.864800 |
20 | 0.937400 |
21 | 0.959400 |
22 | 0.929800 |
23 | 0.892400 |
24 | 0.900700 |
25 | 0.891200 |
26 | 0.910400 |
27 | 0.850800 |
28 | 0.912600 |
29 | 0.832900 |
30 | 0.846400 |
31 | 0.840500 |
32 | 0.856000 |
33 | 0.793800 |
34 | 0.901100 |
35 | 0.871500 |
36 | 0.834300 |
37 | 0.832300 |
38 | 0.810800 |
39 | 0.840100 |
40 | 0.886200 |
41 | 0.823800 |
42 | 0.823300 |
43 | 0.868200 |
44 | 0.851900 |
45 | 0.845500 |
46 | 0.829100 |
47 | 0.826400 |
48 | 0.850900 |
49 | 0.808600 |
50 | 0.832700 |
51 | 0.784200 |
52 | 0.810200 |
53 | 0.785500 |
54 | 0.776400 |
55 | 0.784800 |
56 | 0.796800 |
57 | 0.803300 |
58 | 0.776000 |
59 | 0.829500 |
60 | 0.748200 |
61 | 0.778100 |
62 | 0.757000 |
63 | 0.818700 |
64 | 0.846200 |
65 | 0.811500 |
66 | 0.804400 |
67 | 0.752500 |
68 | 0.768000 |
69 | 0.773200 |
70 | 0.763800 |
71 | 0.725100 |
72 | 0.794800 |
73 | 0.734700 |
74 | 0.732800 |
75 | 0.758000 |
76 | 0.710200 |
77 | 0.781100 |
78 | 0.753400 |
79 | 0.701600 |
80 | 0.758800 |
81 | 0.837000 |
82 | 0.789900 |
83 | 0.775300 |
84 | 0.737000 |
85 | 0.776300 |
86 | 0.755400 |
87 | 0.745100 |
88 | 0.743800 |
89 | 0.693900 |
90 | 0.733400 |
91 | 0.786900 |
92 | 0.766600 |
93 | 0.769400 |
94 | 0.720600 |
95 | 0.730200 |
96 | 0.729800 |
97 | 0.740800 |
98 | 0.767000 |
99 | 0.757500 |
100 | 0.737800 |
101 | 0.728100 |
102 | 0.755200 |
103 | 0.698300 |
104 | 0.711400 |
105 | 0.766700 |
106 | 0.749500 |
107 | 0.705200 |
108 | 0.680300 |
109 | 0.674500 |
110 | 0.706600 |
111 | 0.759000 |
112 | 0.699500 |
113 | 0.709700 |
114 | 0.714800 |
115 | 0.708000 |
116 | 0.700300 |
117 | 0.673500 |
118 | 0.760100 |
119 | 0.694300 |
120 | 0.706500 |
121 | 0.721300 |
122 | 0.698400 |
123 | 0.738900 |
124 | 0.729600 |
125 | 0.696200 |
126 | 0.676000 |
127 | 0.695700 |
128 | 0.729200 |
129 | 0.730000 |
130 | 0.719900 |
131 | 0.726200 |
132 | 0.693100 |
133 | 0.706900 |
134 | 0.708700 |
135 | 0.691700 |
136 | 0.682500 |
137 | 0.727800 |
138 | 0.633700 |
139 | 0.710700 |
140 | 0.653100 |
141 | 0.717000 |
142 | 0.732800 |
143 | 0.677000 |
144 | 0.688600 |
145 | 0.673100 |
146 | 0.678900 |
147 | 0.679900 |
148 | 0.667800 |
149 | 0.643900 |
150 | 0.679000 |
151 | 0.666700 |
152 | 0.695600 |
153 | 0.655300 |
154 | 0.710500 |
155 | 0.659700 |
156 | 0.717600 |
157 | 0.657500 |
158 | 0.657900 |
159 | 0.695600 |
160 | 0.673400 |
161 | 0.642500 |
162 | 0.702800 |
163 | 0.713500 |
164 | 0.674100 |
165 | 0.746000 |
166 | 0.676800 |
167 | 0.669100 |
168 | 0.668800 |
169 | 0.655000 |
170 | 0.684400 |
171 | 0.688200 |
172 | 0.705100 |
173 | 0.669600 |
174 | 0.654800 |
175 | 0.691300 |
176 | 0.640200 |
177 | 0.691600 |
178 | 0.701600 |
179 | 0.718500 |
180 | 0.629500 |
181 | 0.706600 |
182 | 0.661800 |
183 | 0.649300 |
184 | 0.687800 |
185 | 0.623300 |
186 | 0.729500 |
187 | 0.645000 |
188 | 0.723100 |
189 | 0.665900 |
190 | 0.628100 |
191 | 0.707700 |
192 | 0.676500 |
193 | 0.644600 |
194 | 0.658400 |
195 | 0.729700 |
196 | 0.668800 |
197 | 0.672800 |
198 | 0.667000 |
199 | 0.679100 |
200 | 0.656400 |
201 | 0.633200 |
202 | 0.651700 |
203 | 0.648600 |
204 | 0.603300 |
205 | 0.655100 |
206 | 0.637800 |
207 | 0.624800 |
208 | 0.635600 |
209 | 0.640000 |
210 | 0.693500 |
211 | 0.677000 |
212 | 0.625200 |
213 | 0.668800 |
214 | 0.633200 |
215 | 0.643800 |
216 | 0.677900 |
217 | 0.602000 |
218 | 0.616500 |
219 | 0.653500 |
220 | 0.641100 |
221 | 0.624500 |
222 | 0.684600 |
223 | 0.670300 |
224 | 0.675900 |
225 | 0.609500 |
226 | 0.600900 |
227 | 0.642300 |
228 | 0.607700 |
229 | 0.666700 |
230 | 0.613300 |
231 | 0.661400 |
232 | 0.661800 |
233 | 0.627900 |
234 | 0.707200 |
235 | 0.611800 |
236 | 0.611900 |
237 | 0.574400 |
238 | 0.623300 |
239 | 0.681000 |
240 | 0.622300 |
241 | 0.651900 |
242 | 0.614700 |
243 | 0.654900 |
244 | 0.663600 |
245 | 0.670500 |
246 | 0.619700 |
247 | 0.586900 |
248 | 0.644200 |
249 | 0.614600 |
250 | 0.641000 |
251 | 0.633500 |
252 | 0.645700 |
253 | 0.672500 |
254 | 0.635300 |
255 | 0.644100 |
256 | 0.641300 |
257 | 0.569300 |
258 | 0.674100 |
259 | 0.622000 |
260 | 0.659600 |
261 | 0.605200 |
262 | 0.628800 |
263 | 0.606600 |
264 | 0.591900 |
265 | 0.623100 |
266 | 0.604400 |
267 | 0.605600 |
268 | 0.655400 |
269 | 0.695500 |
270 | 0.618400 |
271 | 0.669500 |
272 | 0.641000 |
273 | 0.626000 |
274 | 0.617500 |
275 | 0.620000 |
276 | 0.638700 |
277 | 0.592700 |
278 | 0.648200 |
279 | 0.636100 |
280 | 0.581300 |
281 | 0.557300 |
282 | 0.643300 |
283 | 0.646800 |
284 | 0.625300 |
285 | 0.654400 |
286 | 0.607100 |
287 | 0.593400 |
288 | 0.596900 |
289 | 0.539600 |
290 | 0.620200 |
291 | 0.595400 |
292 | 0.589700 |
293 | 0.642000 |
294 | 0.569100 |
295 | 0.595600 |
296 | 0.594500 |
297 | 0.646400 |
298 | 0.630300 |
299 | 0.658800 |
300 | 0.614100 |
301 | 0.663500 |
302 | 0.649000 |
303 | 0.609400 |
304 | 0.615200 |
305 | 0.628400 |
306 | 0.599600 |
307 | 0.611500 |
308 | 0.605600 |
309 | 0.590200 |
310 | 0.607900 |
311 | 0.627600 |
312 | 0.623900 |
313 | 0.643100 |
314 | 0.609400 |
315 | 0.582000 |
316 | 0.574000 |
317 | 0.600700 |
318 | 0.599200 |
319 | 0.596700 |
320 | 0.620400 |
321 | 0.579700 |
322 | 0.666400 |
323 | 0.576000 |
324 | 0.644500 |
325 | 0.593400 |
326 | 0.624900 |
327 | 0.577800 |
328 | 0.618400 |
329 | 0.586700 |
330 | 0.608200 |
331 | 0.598000 |
332 | 0.580400 |
333 | 0.624300 |
334 | 0.567800 |
335 | 0.593700 |
336 | 0.554100 |
337 | 0.719700 |
338 | 0.551600 |
339 | 0.565500 |
340 | 0.590000 |
341 | 0.591700 |
342 | 0.584800 |
343 | 0.605800 |
344 | 0.641100 |
345 | 0.588000 |
346 | 0.615200 |
347 | 0.567100 |
348 | 0.610200 |
349 | 0.626000 |
350 | 0.610900 |
351 | 0.591800 |
352 | 0.585600 |
353 | 0.599700 |
354 | 0.606800 |
355 | 0.571400 |
356 | 0.612700 |
357 | 0.585900 |
358 | 0.625800 |
359 | 0.642900 |
360 | 0.550300 |
361 | 0.566100 |
362 | 0.604000 |
363 | 0.600600 |
364 | 0.627300 |
365 | 0.521300 |
366 | 0.622500 |
367 | 0.562700 |
368 | 0.577400 |
369 | 0.546600 |
370 | 0.576200 |
371 | 0.582100 |
372 | 0.604100 |
373 | 0.632300 |
374 | 0.626800 |
375 | 0.593400 |
376 | 0.614400 |
377 | 0.566200 |
378 | 0.608800 |
379 | 0.562100 |
380 | 0.564600 |
381 | 0.576500 |
382 | 0.572100 |
383 | 0.573600 |
384 | 0.600700 |
385 | 0.500700 |
386 | 0.618800 |
387 | 0.561100 |
388 | 0.605900 |
389 | 0.579300 |
390 | 0.615000 |
391 | 0.540200 |
392 | 0.561600 |
393 | 0.563700 |
394 | 0.573000 |
395 | 0.597400 |
396 | 0.554300 |
397 | 0.565700 |
398 | 0.620500 |
399 | 0.513900 |
400 | 0.539300 |
401 | 0.609100 |
402 | 0.547700 |
403 | 0.557300 |
404 | 0.585300 |
405 | 0.586300 |
406 | 0.598300 |
407 | 0.547800 |
408 | 0.530200 |
409 | 0.620100 |
410 | 0.568500 |
411 | 0.596900 |
412 | 0.610400 |
413 | 0.587900 |
414 | 0.553600 |
415 | 0.608500 |
416 | 0.519700 |
417 | 0.613200 |
418 | 0.579200 |
419 | 0.613900 |
420 | 0.596300 |
421 | 0.546900 |
422 | 0.589300 |
423 | 0.589900 |
424 | 0.580600 |
425 | 0.584400 |
426 | 0.639800 |
427 | 0.584700 |
428 | 0.596400 |
429 | 0.532800 |
430 | 0.629400 |
431 | 0.560600 |
432 | 0.565700 |
433 | 0.570000 |
434 | 0.595200 |
435 | 0.554300 |
436 | 0.626400 |
437 | 0.611700 |
438 | 0.584300 |
439 | 0.574700 |
440 | 0.611400 |
441 | 0.554900 |
442 | 0.586000 |
443 | 0.594200 |
444 | 0.532100 |
445 | 0.580600 |
446 | 0.590500 |
447 | 0.551300 |
448 | 0.556200 |
449 | 0.566300 |
450 | 0.600100 |
451 | 0.597400 |
452 | 0.526500 |
453 | 0.609900 |
454 | 0.572600 |
455 | 0.629700 |
456 | 0.509900 |
457 | 0.585800 |
458 | 0.569600 |
459 | 0.541300 |
460 | 0.525000 |
461 | 0.543200 |
462 | 0.597100 |
463 | 0.539400 |
464 | 0.566400 |
465 | 0.594900 |
466 | 0.595700 |
467 | 0.530100 |
468 | 0.525500 |
469 | 0.540600 |
470 | 0.577400 |
471 | 0.543700 |
472 | 0.534800 |
473 | 0.607000 |
474 | 0.624600 |
475 | 0.571200 |
476 | 0.500100 |
477 | 0.571600 |
478 | 0.548500 |
479 | 0.546200 |
480 | 0.550800 |
481 | 0.553000 |
482 | 0.541900 |
483 | 0.520500 |
484 | 0.566200 |
485 | 0.573500 |
486 | 0.581800 |
487 | 0.622700 |
488 | 0.547400 |
489 | 0.566500 |
490 | 0.542000 |
491 | 0.544900 |
492 | 0.541100 |
493 | 0.515500 |
494 | 0.587000 |
495 | 0.518900 |
496 | 0.514400 |
497 | 0.545600 |
498 | 0.595700 |
499 | 0.551900 |
500 | 0.539100 |
501 | 0.548600 |
502 | 0.556300 |
503 | 0.523200 |
504 | 0.556300 |
505 | 0.558400 |
506 | 0.508500 |
507 | 0.553200 |
508 | 0.557600 |
509 | 0.572900 |
510 | 0.597800 |
511 | 0.524900 |
512 | 0.529500 |
513 | 0.566900 |
514 | 0.562600 |
515 | 0.546500 |
516 | 0.517900 |
517 | 0.531000 |
518 | 0.571500 |
519 | 0.503300 |
520 | 0.578200 |
521 | 0.598000 |
522 | 0.505400 |
523 | 0.533900 |
524 | 0.527300 |
525 | 0.552600 |
526 | 0.554500 |
527 | 0.534700 |
528 | 0.561500 |
529 | 0.553300 |
530 | 0.509700 |
531 | 0.531900 |
532 | 0.525000 |
533 | 0.571200 |
534 | 0.525800 |
535 | 0.593100 |
536 | 0.545800 |
537 | 0.522400 |
538 | 0.588000 |
539 | 0.556900 |
540 | 0.553500 |
541 | 0.561000 |
542 | 0.546200 |
543 | 0.510300 |
544 | 0.552300 |
545 | 0.526000 |
546 | 0.531100 |
547 | 0.509700 |
548 | 0.482200 |
549 | 0.547000 |
550 | 0.532000 |
551 | 0.534600 |
552 | 0.546000 |
553 | 0.542100 |
554 | 0.518800 |
555 | 0.603500 |
556 | 0.514000 |
557 | 0.538500 |
558 | 0.551000 |
559 | 0.548400 |
560 | 0.542600 |
561 | 0.533900 |
562 | 0.572400 |
563 | 0.556300 |
564 | 0.538900 |
565 | 0.586900 |
566 | 0.518200 |
567 | 0.472500 |
568 | 0.554000 |
569 | 0.530600 |
570 | 0.552300 |
571 | 0.523500 |
572 | 0.586100 |
573 | 0.540100 |
574 | 0.561500 |
575 | 0.540900 |
576 | 0.525000 |
577 | 0.542000 |
578 | 0.605800 |
579 | 0.549400 |
580 | 0.508100 |
581 | 0.523500 |
582 | 0.526300 |
583 | 0.521100 |
584 | 0.525300 |
585 | 0.523600 |
586 | 0.506800 |
587 | 0.547200 |
588 | 0.550000 |
589 | 0.571600 |
590 | 0.539200 |
591 | 0.561000 |
592 | 0.529800 |
593 | 0.488400 |
594 | 0.512300 |
595 | 0.503700 |
596 | 0.520400 |
597 | 0.523200 |
598 | 0.527600 |
599 | 0.569400 |
600 | 0.515700 |
601 | 0.540700 |
602 | 0.504500 |
603 | 0.523900 |
604 | 0.527400 |
605 | 0.539900 |
606 | 0.507100 |
607 | 0.484200 |
608 | 0.525100 |
609 | 0.568100 |
610 | 0.565100 |
611 | 0.535700 |
612 | 0.507300 |
613 | 0.529300 |
614 | 0.543900 |
615 | 0.531400 |
616 | 0.520300 |
617 | 0.527800 |
618 | 0.560800 |
619 | 0.522200 |
620 | 0.491600 |
621 | 0.548300 |
622 | 0.560200 |
/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants. warnings.warn(
adapter_model.safetensors: 0%| | 0.00/50.5M [00:00<?, ?B/s]
In [10]:
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
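If GPU memory is not actually released after the del calls (references can linger in the notebook namespace), forcing a garbage-collection pass first usually helps; a small optional addition:

import gc
gc.collect()              # drop any lingering Python references first
torch.cuda.empty_cache()  # then let CUDA release the cached blocks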
In [11]:
#### MERGE PEFT ADAPTER AND BASE MODEL ####
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM

# Load PEFT model on CPU
config = PeftConfig.from_pretrained(args.output_dir)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, args.output_dir)
# NOTE: this call reloads the adapter onto a fresh copy of the base checkpoint and discards
# the resized model built just above, so the merged weights keep the original vocabulary size
model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
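Before reloading anything for inference it is worth checking that the merged checkpoint's embedding matrix covers the saved ChatML tokenizer; a hedged check I added (the mismatch it can reveal is exactly what the final cell runs into):

# The tokenizer saved by the trainer contains the ChatML tokens; the merged model's
# embedding rows must cover all of those ids, otherwise generation will index out of range.
print(len(tokenizer), merged_model.get_input_embeddings().weight.shape[0])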
In [12]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
#peft_model_id = "./tinyllama_hindi_sft_sentence_retrieval"
peft_model_id = args.output_dir
# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
peft_model_id,
device_map="auto",
torch_dtype=torch.float16
)
# load into pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].
In [13]:
from datasets import load_dataset
from random import randint
# Load our test dataset
eval_dataset = load_dataset("json", data_files="test_dataset.json", split="train")
rand_idx = randint(0, len(eval_dataset) - 1)
# Test on sample
prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)
print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}")
print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
Generating train split: 0 examples [00:00, ? examples/s]
/opt/conda/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:389: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. warnings.warn( /opt/conda/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:394: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. warnings.warn( ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [37,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [38,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [39,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [40,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [41,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [42,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [43,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [44,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [45,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [46,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [47,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [48,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [49,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [50,0,0] Assertion `srcIndex < srcSelectDimSize` failed. 
../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [51,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [52,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [53,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [54,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [55,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [56,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [57,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [58,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [59,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [60,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [61,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [62,0,0] Assertion `srcIndex < srcSelectDimSize` failed. ../aten/src/ATen/native/cuda/Indexing.cu:1292: indexSelectLargeIndex: block: [292,0,0], thread: [63,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Cell In[13], line 11 9 # Test on sample 10 prompt = pipe.tokenizer.apply_chat_template(eval_dataset[rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True) ---> 11 outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id) 13 print(f"Query:\n{eval_dataset[rand_idx]['messages'][1]['content']}") 14 print(f"Original Answer:\n{eval_dataset[rand_idx]['messages'][2]['content']}") File /opt/conda/lib/python3.10/site-packages/transformers/pipelines/text_generation.py:208, in TextGenerationPipeline.__call__(self, text_inputs, **kwargs) 167 def __call__(self, text_inputs, **kwargs): 168 """ 169 Complete the prompt(s) given as inputs. 170 (...) 206 ids of the generated text. 207 """ --> 208 return super().__call__(text_inputs, **kwargs) File /opt/conda/lib/python3.10/site-packages/transformers/pipelines/base.py:1140, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs) 1132 return next( 1133 iter( 1134 self.get_iterator( (...) 1137 ) 1138 ) 1139 else: -> 1140 return self.run_single(inputs, preprocess_params, forward_params, postprocess_params) File /opt/conda/lib/python3.10/site-packages/transformers/pipelines/base.py:1147, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params) 1145 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params): 1146 model_inputs = self.preprocess(inputs, **preprocess_params) -> 1147 model_outputs = self.forward(model_inputs, **forward_params) 1148 outputs = self.postprocess(model_outputs, **postprocess_params) 1149 return outputs File /opt/conda/lib/python3.10/site-packages/transformers/pipelines/base.py:1046, in Pipeline.forward(self, model_inputs, **forward_params) 1044 with inference_context(): 1045 model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) -> 1046 model_outputs = self._forward(model_inputs, **forward_params) 1047 model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu")) 1048 else: File /opt/conda/lib/python3.10/site-packages/transformers/pipelines/text_generation.py:271, in TextGenerationPipeline._forward(self, model_inputs, **generate_kwargs) 268 generate_kwargs["min_length"] += prefix_length 270 # BS x SL --> 271 generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) 272 out_b = generated_sequence.shape[0] 273 if self.framework == "pt": File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1130, in PeftModelForCausalLM.generate(self, **kwargs) 1128 self.base_model.generation_config = self.generation_config 1129 try: -> 1130 outputs = self.base_model.generate(**kwargs) 1131 except: 1132 self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs) 112 @functools.wraps(func) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs) File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:1718, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, 
negative_prompt_ids, negative_prompt_attention_mask, **kwargs) 1701 return self.assisted_decoding( 1702 input_ids, 1703 assistant_model=assistant_model, (...) 1714 **model_kwargs, 1715 ) 1716 if generation_mode == GenerationMode.GREEDY_SEARCH: 1717 # 11. run greedy search -> 1718 return self.greedy_search( 1719 input_ids, 1720 logits_processor=logits_processor, 1721 stopping_criteria=stopping_criteria, 1722 pad_token_id=generation_config.pad_token_id, 1723 eos_token_id=generation_config.eos_token_id, 1724 output_scores=generation_config.output_scores, 1725 return_dict_in_generate=generation_config.return_dict_in_generate, 1726 synced_gpus=synced_gpus, 1727 streamer=streamer, 1728 **model_kwargs, 1729 ) 1731 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: 1732 if not model_kwargs["use_cache"]: File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2579, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs) 2576 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 2578 # forward pass to get next token -> 2579 outputs = self( 2580 **model_inputs, 2581 return_dict=True, 2582 output_attentions=output_attentions, 2583 output_hidden_states=output_hidden_states, 2584 ) 2586 if synced_gpus and this_peer_finished: 2587 continue # don't waste resources running the code we don't need File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs) File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1181, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict) 1178 return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1180 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) -> 1181 outputs = self.model( 1182 input_ids=input_ids, 1183 attention_mask=attention_mask, 1184 position_ids=position_ids, 1185 past_key_values=past_key_values, 1186 inputs_embeds=inputs_embeds, 1187 use_cache=use_cache, 1188 output_attentions=output_attentions, 1189 output_hidden_states=output_hidden_states, 1190 return_dict=return_dict, 1191 ) 1193 hidden_states = outputs[0] 1194 if self.config.pretraining_tp > 1: File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs) File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1033, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) 1029 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 1030 elif self._use_sdpa and not output_attentions: 1031 # output_attentions=True can not be supported when using SDPA, and we fall back on 1032 # the manual implementation that requires a 4D causal mask in all cases. -> 1033 attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( 1034 attention_mask, 1035 (batch_size, seq_length), 1036 inputs_embeds, 1037 past_key_values_length, 1038 ) 1039 else: 1040 # 4d mask is passed through the layers 1041 attention_mask = _prepare_4d_causal_attention_mask( 1042 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 1043 ) File /opt/conda/lib/python3.10/site-packages/transformers/modeling_attn_mask_utils.py:343, in _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window) 340 is_tracing = torch.jit.is_tracing() 342 if attention_mask is not None: --> 343 if torch.all(attention_mask == 1): 344 if is_tracing: 345 pass RuntimeError: CUDA error: device-side assert triggered CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1. 
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
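The device-side assert (srcIndex < srcSelectDimSize) is an embedding lookup with a token id outside the model's vocabulary: the tokenizer saved in the output directory contains the ChatML tokens added by setup_chat_format, but AutoPeftModelForCausalLM rebuilt the model from the original base checkpoint, whose embedding matrix was never resized. A hedged fix sketch (my addition, not part of the original run): resize the base model to the saved tokenizer before attaching the adapter, then rebuild the pipeline.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Reload the ChatML tokenizer the trainer saved alongside the adapter.
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)

# Load the base checkpoint and grow its embeddings to cover the added special tokens.
base_model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16,
    device_map="auto",
)
base_model.resize_token_embeddings(len(tokenizer))

# Attach the trained LoRA adapter on top of the resized model and rebuild the pipeline.
model = PeftModel.from_pretrained(base_model, args.output_dir)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)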