Files
WeChatMsg/MemoAI/qwen2-0.5b/train.ipynb
睿 安 abba5cb273 init
2026-01-21 16:48:36 +08:00

420 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "de53995b-32ed-4722-8cac-ba104c8efacb",
"metadata": {},
"source": [
"# 导入环境"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52fac949-4150-4091-b0c3-2968ab5e385c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"import pandas as pd\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e098d9eb",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"df = pd.read_json('train.json')\n",
"ds = Dataset.from_pandas(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ac92d42-efae-49b1-a00e-ccaa75b98938",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"ds[:3]"
]
},
{
"cell_type": "markdown",
"id": "380d9f69-9e98-4d2d-b044-1d608a057b0b",
"metadata": {},
"source": [
"# 下载模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "312d6439-1932-44a3-b592-9adbdb7ab702",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from modelscope import snapshot_download\n",
"model_dir = snapshot_download('qwen/Qwen2-0.5B-Instruct', cache_dir='qwen2-0.5b/')"
]
},
{
"cell_type": "markdown",
"id": "51d05e5d-d14e-4f03-92be-9a9677d41918",
"metadata": {},
"source": [
"# 处理数据集"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74ee5a67-2e55-4974-b90e-cbf492de500a",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct/', use_fast=False, trust_remote_code=True)\n",
"tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2503a5fa-9621-4495-9035-8e7ef6525691",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def process_func(example):\n",
" MAX_LENGTH = 384 # Qwen2分词器会将一个中文字切分为多个token因此需要放开一些最大长度保证数据的完整性\n",
" input_ids, attention_mask, labels = [], [], []\n",
" instruction = tokenizer(f\"<|im_start|>system\\n现在你需要扮演我,和我的微信好友快乐聊天!<|im_end|>\\n<|im_start|>user\\n{example['instruction'] + example['input']}<|im_end|>\\n<|im_start|>assistant\\n\", add_special_tokens=False)\n",
" response = tokenizer(f\"{example['output']}\", add_special_tokens=False)\n",
" input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
" attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1] # 因为eos token咱们也是要关注的所以 补充为1\n",
" labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id] \n",
" if len(input_ids) > MAX_LENGTH: # 做一个截断\n",
" input_ids = input_ids[:MAX_LENGTH]\n",
" attention_mask = attention_mask[:MAX_LENGTH]\n",
" labels = labels[:MAX_LENGTH]\n",
" return {\n",
" \"input_ids\": input_ids,\n",
" \"attention_mask\": attention_mask,\n",
" \"labels\": labels\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84f870d6-73a9-4b0f-8abf-687b32224ad8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
"tokenized_id"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f7e15a0-4d9a-4935-9861-00cc472654b1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenizer.decode(tokenized_id[0]['input_ids'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97f16f66-324a-454f-8cc3-ef23b100ecff",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1][\"labels\"])))"
]
},
{
"cell_type": "markdown",
"id": "424823a8-ed0d-4309-83c8-3f6b1cdf274c",
"metadata": {},
"source": [
"# 创建模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "170764e5-d899-4ef4-8c53-36f6dec0d198",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct', device_map=\"auto\",torch_dtype=torch.bfloat16)\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2323eac7-37d5-4288-8bc5-79fac7113402",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.enable_input_require_grads()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f808b05c-f2cb-48cf-a80d-0c42be6051c7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.dtype"
]
},
{
"cell_type": "markdown",
"id": "13d71257-3c1c-4303-8ff8-af161ebc2cf1",
"metadata": {},
"source": [
"# lora "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d304ae2-ab60-4080-a80d-19cac2e3ade3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from peft import LoraConfig, TaskType, get_peft_model\n",
"\n",
"config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM, \n",
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
" inference_mode=False, # 训练模式\n",
" r=8, # Lora 秩\n",
" lora_alpha=32, # Lora alaph具体作用参见 Lora 原理\n",
" lora_dropout=0.1# Dropout 比例\n",
")\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c2489c5-eaab-4e1f-b06a-c3f914b4bf8e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model = get_peft_model(model, config)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebf5482b-fab9-4eb3-ad88-c116def4be12",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.print_trainable_parameters()"
]
},
{
"cell_type": "markdown",
"id": "ca055683-837f-4865-9c57-9164ba60c00f",
"metadata": {},
"source": [
"# 配置训练参数"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e76bbff-15fd-4995-a61d-8364dc5e9ea0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"args = TrainingArguments(\n",
" output_dir=\"./output/\",\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4,\n",
" logging_steps=10,\n",
" num_train_epochs=3,\n",
" learning_rate=1e-4,\n",
" gradient_checkpointing=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f142cb9c-ad99-48e6-ba86-6df198f9ed96",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" args=args,\n",
" train_dataset=tokenized_id,\n",
" data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aec9bc36-b297-45af-99e1-d4c4d82be081",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "8abb2327-458e-4e96-ac98-2141b5b97c8e",
"metadata": {},
"source": [
"# 合并加载模型,这里的路径可能有点不太一样,lora_path填写为Output的最后的checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd2a415a-a9ad-49ea-877f-243558a83bfc",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"import torch\n",
"from peft import PeftModel\n",
"\n",
"mode_path = './qwen2-0.5b/qwen/Qwen2-0___5B-Instruct'\n",
"lora_path = './output/checkpoint-10' #修改这里\n",
"# 加载tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)\n",
"\n",
"# 加载模型\n",
"model = AutoModelForCausalLM.from_pretrained(mode_path, device_map=\"auto\",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()\n",
"\n",
"# 加载lora权重\n",
"model = PeftModel.from_pretrained(model, model_id=lora_path)\n",
"\n",
"prompt = \"在干啥呢?\"\n",
"inputs = tokenizer.apply_chat_template([{\"role\": \"system\", \"content\": \"现在你需要扮演我,和我的微信好友快乐聊天!\"},{\"role\": \"user\", \"content\": prompt}],\n",
" add_generation_prompt=True,\n",
" tokenize=True,\n",
" return_tensors=\"pt\",\n",
" return_dict=True\n",
" ).to('cuda')\n",
"\n",
"\n",
"gen_kwargs = {\"max_length\": 2500, \"do_sample\": True, \"top_k\": 1}\n",
"with torch.no_grad():\n",
" outputs = model.generate(**inputs, **gen_kwargs)\n",
" outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
" print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
"\n",
"# 保存合并后的模型和tokenizer\n",
"save_directory = './model_merge'\n",
"\n",
"# Merge the LoRA weights into the base model before saving; otherwise\n",
"# save_pretrained on a PeftModel writes only the adapter, not the merged model.\n",
"model = model.merge_and_unload()\n",
"model.save_pretrained(save_directory)\n",
"\n",
"# 保存tokenizer\n",
"tokenizer.save_pretrained(save_directory)"
]
},
{
"cell_type": "markdown",
"id": "b67e5e0a-2566-4483-9bce-92b5be8b4b34",
"metadata": {},
"source": [
"# 然后把模型上传到modelscope开始下一步"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dafe4f24-af5c-407e-abbc-eefd9d44cb15",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}