init
This commit is contained in:
599
MemoAI/api_server.py
Normal file
599
MemoAI/api_server.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""
|
||||
This script implements an API for the ChatGLM3-6B model,
|
||||
formatted similarly to OpenAI's API (https://platform.openai.com/docs/api-reference/chat).
|
||||
It's designed to be run as a web server using FastAPI and uvicorn,
|
||||
making the ChatGLM3-6B model accessible through OpenAI Client.
|
||||
|
||||
Key Components and Features:
|
||||
- Model and Tokenizer Setup: Configures the model and tokenizer paths and loads them.
|
||||
- FastAPI Configuration: Sets up a FastAPI application with CORS middleware for handling cross-origin requests.
|
||||
- API Endpoints:
|
||||
- "/v1/models": Lists the available models, specifically ChatGLM3-6B.
|
||||
- "/v1/chat/completions": Processes chat completion requests with options for streaming and regular responses.
|
||||
- "/v1/embeddings": Processes Embedding request of a list of text inputs.
|
||||
- Token Limit Caution: In the OpenAI API, 'max_tokens' is equivalent to HuggingFace's 'max_new_tokens', not 'max_length'.
|
||||
For instance, setting 'max_tokens' to 8192 for a 6b model would result in an error due to the model's inability to output
|
||||
that many tokens after accounting for the history and prompt tokens.
|
||||
- Stream Handling and Custom Functions: Manages streaming responses and custom function calls within chat responses.
|
||||
- Pydantic Models: Defines structured models for requests and responses, enhancing API documentation and type safety.
|
||||
- Main Execution: Initializes the model and tokenizer, and starts the FastAPI app on the designated host and port.
|
||||
|
||||
Note:
|
||||
This script doesn't include the setup for special tokens or multi-GPU support by default.
|
||||
Users need to configure their special tokens and can enable multi-GPU support as per the provided instructions.
|
||||
The embedding model is only supported on a single GPU.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import tiktoken
|
||||
import torch
|
||||
import uvicorn
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Response, Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Literal, Optional, Union
|
||||
from loguru import logger
|
||||
from peft import AutoPeftModelForCausalLM
|
||||
from pydantic import BaseModel, Field
|
||||
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
|
||||
from utils import process_response, generate_chatglm3, generate_stream_chatglm3
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
# Raise the SSE keep-alive ping interval so long generations are not
# interrupted by frequent pings.
EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

# LLM weight / tokenizer locations, overridable via environment variables.
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

# Embedding model location used by /v1/embeddings, overridable via environment.
EMBEDDING_PATH = os.environ.get('EMBEDDING_PATH', 'BAAI/bge-large-zh-v1.5')
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan hook: no startup work; on shutdown, release cached
    CUDA memory so the GPU is left clean for other processes."""
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
# FastAPI application; `lifespan` frees GPU memory on shutdown.
app = FastAPI(lifespan=lifespan)

# Allow cross-origin browser clients to reach the API directly.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# maximally permissive — tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
    """One entry of the OpenAI-style /v1/models listing."""
    # Model identifier (e.g. "chatglm3-6b").
    id: str
    object: str = "model"
    # Unix timestamp of when the card was created (defaults to "now").
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None
|
||||
|
||||
|
||||
class ModelList(BaseModel):
    """OpenAI-style list wrapper returned by /v1/models."""
    object: str = "list"
    data: List[ModelCard] = []
|
||||
|
||||
|
||||
class FunctionCallResponse(BaseModel):
    """A parsed tool invocation: the function's name and its raw arguments."""
    name: Optional[str] = None
    # Arguments as emitted by the model (typically a JSON string).
    arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """A single turn in an OpenAI-style conversation."""
    # Who produced the message; "function" carries a tool's output back.
    role: Literal["user", "assistant", "system", "function"]
    # Message text. May be None when the assistant answers with only a
    # function call (was `str = None`, which contradicted the annotation).
    content: Optional[str] = None
    # Function name, used together with role="function".
    name: Optional[str] = None
    # Parsed tool invocation attached to an assistant message.
    function_call: Optional[FunctionCallResponse] = None
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
    """Incremental message fragment for streaming (chat.completion.chunk)."""
    role: Optional[Literal["user", "assistant", "system"]] = None
    # Newly generated text since the previous chunk.
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None
|
||||
|
||||
|
||||
## for Embedding
class EmbeddingRequest(BaseModel):
    """Request body of /v1/embeddings: texts to embed and the model name."""
    input: List[str]
    model: str
|
||||
|
||||
|
||||
class CompletionUsage(BaseModel):
    """Token accounting block mirroring OpenAI's usage object."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
||||
|
||||
|
||||
class EmbeddingResponse(BaseModel):
    """Response of /v1/embeddings: one entry per input text plus usage."""
    data: list
    model: str
    object: str
    usage: CompletionUsage
|
||||
|
||||
|
||||
# for ChatCompletionRequest


class UsageInfo(BaseModel):
    """Token accounting for chat completions; accumulated per request."""
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
    """Request body of /v1/chat/completions (OpenAI-compatible subset)."""
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    # Maximum NEW tokens (HuggingFace max_new_tokens semantics, not
    # max_length) — see the module docstring's token-limit caution.
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    # OpenAI-style tool definitions forwarded to the generation backend.
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
    """One non-streaming completion choice."""
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
    """One streaming completion choice (delta rather than full message)."""
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]]
    index: int
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
    """Top-level response for both streaming chunks and full completions."""
    model: str
    # Empty string for this open-source deployment (no server-side ids).
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> Response:
    """Health check: always answers HTTP 200 with an empty body."""
    return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
    """Embed a batch of texts, mirroring OpenAI's /v1/embeddings.

    Each input string is encoded with the global SentenceTransformer model
    and returned as a plain list of floats, together with token usage.

    :param request: texts to embed plus the model name to echo back.
    :return: dict matching EmbeddingResponse.
    """
    embeddings = [embedding_model.encode(text) for text in request.input]
    embeddings = [embedding.tolist() for embedding in embeddings]

    # One encoding instance for the whole request (hoisted out of the loop).
    encoding = tiktoken.get_encoding('cl100k_base')

    def num_tokens_from_string(string: str) -> int:
        """Return the number of cl100k_base tokens in *string*."""
        return len(encoding.encode(string))

    # Count prompt_tokens with the same tokenizer as total_tokens. The
    # original used a whitespace split for prompt_tokens, which undercounts
    # CJK text and disagreed with the tiktoken-based total.
    prompt_tokens = sum(num_tokens_from_string(text) for text in request.input)

    response = {
        "data": [
            {
                "object": "embedding",
                "embedding": embedding,
                "index": index,
            }
            for index, embedding in enumerate(embeddings)
        ],
        "model": request.model,
        "object": "list",
        "usage": CompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=0,
            total_tokens=prompt_tokens,
        ),
    }
    return response
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return the single model this server exposes, OpenAI-style."""
    return ModelList(data=[ModelCard(id="chatglm3-6b")])
|
||||
|
||||
|
||||
def insert_custom_message(messages: List[ChatMessage] = Body(...)):
    """Prepend the fixed persona system prompt to *messages*.

    Mutates the list in place and returns it. Called by
    create_chat_completion when the client supplied no system message.
    """
    # NOTE(review): Body(...) as a default only matters for FastAPI
    # dependency injection; this helper is called directly — confirm intent.
    custom_message = ChatMessage(
        role="system",
        content="你必须记住,你是一个只负责聊天的人工智能助手,所以你只能根据上下文回答你朋友的问题,要求语句完整、通顺、文采斐然,回答具有情感,不能胡乱作答、不能断章取义、不能上下文不连贯,禁止回答无关内容。"
                "如果你不能理解你朋友的话,你必须开启一个新的话题来引导你朋友进行对话,例如你可以问你的朋友一个日常问题。\n"
                "你是司小远,一个聪明、热情、善良的人工智能助手,后面的对话来自你的朋友,你要认真地回答他。"
    )
    messages.insert(0, custom_message)
    return messages
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """Handle a chat completion request, streaming or not.

    Streaming path: peek at the first generated fragment; if it does not
    look like a tool call, pass the stream through as SSE. Otherwise parse
    the tool call, append the (empty) tool response to the history and
    stream the follow-up generation.

    Non-streaming path: generate once, optionally parse a tool call, and
    return a full ChatCompletionResponse with accumulated usage.
    """
    global model, tokenizer

    # Reject empty requests and requests whose last turn is already the
    # assistant's (nothing to respond to).
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")
    messages = request.messages
    # Inject the default persona system prompt only when the client did not
    # provide a system message of its own.
    if request.messages and request.messages[0].role == 'system':
        messages = request.messages
    else:
        if request.messages:
            messages = insert_custom_message(request.messages)
        else:
            messages = request.messages
    print(type(request.messages), request.messages)
    gen_params = dict(
        messages=messages,
        temperature=request.temperature,
        top_p=request.top_p,
        # max_tokens here means max NEW tokens (HF max_new_tokens).
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        tools=request.tools,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:

        # Use the stream mode to read the first few characters; if it is
        # not a function call, stream the output directly.
        predict_stream_generator = predict_stream(request.model, gen_params)
        # return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
        # Consume the first fragment to decide between plain text and tools.
        output = next(predict_stream_generator)
        print(output)
        # logger.debug(f"First result output:\n{output}")
        if not contains_custom_function(output):
            return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

        # Obtain the result directly at one time and determine whether tools needs to be called.
        # logger.debug(f"First result output:\n{output}")

        function_call = None
        if output and request.tools:
            try:
                function_call = process_response(output, use_tool=True)
            except:
                logger.warning("Failed to parse tool call")

        # CallFunction
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

            """
            In this demo, we did not register any tools.
            You can use the tools that have been implemented in our `tools_using_demo` and implement your own streaming tool implementation here.
            Similar to the following method:
                function_args = json.loads(function_call.arguments)
                tool_response = dispatch_tool(tool_name: str, tool_params: dict)
            """
            tool_response = ""

            if not gen_params.get("messages"):
                gen_params["messages"] = []

            # Record the model's tool request, then the (empty) tool result,
            # so the follow-up generation sees the full exchange.
            gen_params["messages"].append(ChatMessage(
                role="assistant",
                content=output,
            ))
            gen_params["messages"].append(ChatMessage(
                role="function",
                name=function_call.name,
                content=tool_response,
            ))

            # Streaming output of results after function calls
            generate = predict(request.model, gen_params)
            return EventSourceResponse(generate, media_type="text/event-stream")

        else:
            # Handled to avoid exceptions in the above parsing function process.
            generate = parse_output_text(request.model, output)
            return EventSourceResponse(generate, media_type="text/event-stream")

    # Here is the handling of stream = False
    response = generate_chatglm3(model, tokenizer, gen_params)

    # Remove the first newline character
    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    function_call, finish_reason = None, "stop"
    if request.tools:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except:
            logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.")

    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )

    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    # Accumulate the backend's reported usage into the response usage.
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        id="",  # for open_source model, id is empty
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )
|
||||
|
||||
|
||||
async def predict(model_id: str, params: dict):
    """Stream a chat completion as OpenAI-style SSE chunk payloads.

    Yields JSON strings: an opening role-only delta, one delta per new text
    fragment (with any parsed function_call attached), a closing empty
    delta with finish_reason="stop", and finally the literal '[DONE]'.
    """
    global model, tokenizer

    # Opening chunk: announce the assistant role with no content yet.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    previous_text = ""
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        # The generator yields the full text so far; emit only the new tail.
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        # Skip empty updates unless they carry a function-call signal.
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except:
                logger.warning(
                    "Failed to parse tool call, maybe the response is not a tool call or have been answered.")

        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            id="",
            choices=[choice_data],
            object="chat.completion.chunk"
        )
        yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    # Closing chunk plus the SSE termination sentinel.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id,
        id="",
        choices=[choice_data],
        object="chat.completion.chunk"
    )
    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
    yield '[DONE]'
|
||||
|
||||
|
||||
def predict_stream(model_id, gen_params):
    """
    The function call is compatible with stream mode output.

    The first seven characters are determined.
    If not a function call, the stream output is directly generated.
    Otherwise, the complete character content of the function call is returned.

    :param model_id: model name echoed back in every chunk
    :param gen_params: generation parameters for generate_stream_chatglm3
    :return: generator of JSON chunk strings; when a tool call is detected
        it yields the raw accumulated text instead, for the caller to parse
    """
    output = ""
    is_function_call = False
    has_send_first_chunk = False
    # Debug dump of the incoming parameters ("参数" = "parameters").
    print('参数')
    print(model_id,gen_params)
    for new_response in generate_stream_chatglm3(model, tokenizer, gen_params):
        decoded_unicode = new_response["text"]
        # The backend yields full text so far; slice off the sent prefix.
        delta_text = decoded_unicode[len(output):]
        output = decoded_unicode

        # When it is not a function call,
        # try to judge whether it is a function call according to the special function prefix
        if not is_function_call:

            # Determine whether a function is called
            is_function_call = contains_custom_function(output)
            if is_function_call:
                continue

            # Non-function call, direct stream output
            finish_reason = new_response["finish_reason"]

            # Send an empty string first to avoid truncation by subsequent next() operations.
            if not has_send_first_chunk:
                message = DeltaMessage(
                    content="",
                    role="assistant",
                    function_call=None,
                )
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=message,
                    finish_reason=finish_reason
                )
                chunk = ChatCompletionResponse(
                    model=model_id,
                    id="",
                    choices=[choice_data],
                    created=int(time.time()),
                    object="chat.completion.chunk"
                )
                yield "{}".format(chunk.model_dump_json(exclude_unset=True))

            # First real chunk carries the whole accumulated output because
            # the caller consumed the initial fragment with next().
            send_msg = delta_text if has_send_first_chunk else output
            has_send_first_chunk = True
            message = DeltaMessage(
                content=send_msg,
                role="assistant",
                function_call=None,
            )
            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=message,
                finish_reason=finish_reason
            )
            chunk = ChatCompletionResponse(
                model=model_id,
                id="",
                choices=[choice_data],
                created=int(time.time()),
                object="chat.completion.chunk"
            )
            yield "{}".format(chunk.model_dump_json(exclude_unset=True))

    # Tool call detected: hand back the raw text; otherwise close the stream.
    if is_function_call:
        yield output
    else:
        yield '[DONE]'
|
||||
|
||||
|
||||
async def parse_output_text(model_id: str, value: str):
    """
    Emit a fixed text as a minimal SSE stream.

    Yields one chunk carrying the whole of *value*, a closing chunk with
    finish_reason="stop", and finally the '[DONE]' sentinel.

    :param model_id:
    :param value:
    :return:
    """
    opening = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant", content=value),
        finish_reason=None
    )
    yield ChatCompletionResponse(
        model=model_id, id="", choices=[opening], object="chat.completion.chunk"
    ).model_dump_json(exclude_unset=True)

    closing = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    yield ChatCompletionResponse(
        model=model_id, id="", choices=[closing], object="chat.completion.chunk"
    ).model_dump_json(exclude_unset=True)
    yield '[DONE]'
|
||||
|
||||
|
||||
def contains_custom_function(value: str) -> bool:
    """
    Determine whether 'function_call' according to a special function prefix.

    For example, the functions defined in "tools_using_demo/tool_register.py"
    are all "get_xxx" and start with "get_"

    [Note] This is not a rigorous judgment method, only for reference.

    :param value: model output text (may be None or empty)
    :return: True when the text looks like a tool call
    """
    # Wrap in bool() so an empty string or None yields False rather than
    # leaking the falsy value itself (the annotation promises a bool).
    return bool(value and 'get_' in value)
|
||||
|
||||
|
||||
# NOTE(review): these imports sit mid-module and duplicate several
# top-of-file imports; kept in place to avoid reordering side effects.
from pathlib import Path
from typing import Annotated, Union

import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

# A loaded causal LM: either a plain HF model or a PEFT (LoRA) wrapper.
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
# A tokenizer in either the slow or the fast implementation.
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
|
||||
|
||||
def _resolve_path(path: Union[str, Path]) -> Path:
|
||||
return Path(path).expanduser().resolve()
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    """Load a causal LM (plain or LoRA-finetuned) and a matching tokenizer.

    If *model_dir* contains adapter_config.json it is treated as a PEFT
    adapter and the tokenizer is loaded from the recorded base checkpoint;
    otherwise both model and tokenizer come from *model_dir* itself.
    """
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        # The adapter records which base model it was trained on.
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=trust_remote_code, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=trust_remote_code
    )
    return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Load the fine-tuned LLM. The checkpoint path can be overridden via the
    # CHECKPOINT_PATH environment variable so the server is not tied to one
    # machine's directory layout; the default preserves the original path.
    checkpoint_dir = os.environ.get(
        'CHECKPOINT_PATH',
        r'E:\Project\Python\ChatGLM3\finetune_demo\output03-24\checkpoint-224000',
    )
    model, tokenizer = load_model_and_tokenizer(checkpoint_dir)

    # Load the embedding model used by /v1/embeddings (single-GPU only).
    embedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")
    uvicorn.run(app, host='0.0.0.0', port=8002, workers=1)
|
||||
BIN
MemoAI/img/img.png
Normal file
BIN
MemoAI/img/img.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
BIN
MemoAI/img/img2.png
Normal file
BIN
MemoAI/img/img2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
BIN
MemoAI/img/img3.png
Normal file
BIN
MemoAI/img/img3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 359 KiB |
BIN
MemoAI/img/img4.png
Normal file
BIN
MemoAI/img/img4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 58 KiB |
26
MemoAI/merge_json.py
Normal file
26
MemoAI/merge_json.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import json
import os

# Root directory holding the exported chat-record JSON files.
data_dir = r'E:\Project\Python\MemoTrace\data\聊天记录'

dev_res = []
train_res = []

# Walk the tree and merge every *.json file: files whose names end in
# "train.json" feed the training split, everything else feeds the dev split.
for root, _dirs, names in os.walk(data_dir):
    for name in names:
        if not name.endswith('.json'):
            continue
        print(name, root)
        full_path = os.path.join(root, name)
        with open(full_path, 'r', encoding='utf-8') as fh:
            records = json.load(fh)
        if not records:
            continue
        if name.endswith('train.json'):
            train_res += records
        else:
            dev_res += records

# Write the merged splits next to this script.
with open('train.json', 'w', encoding='utf-8') as f:
    json.dump(train_res, f, ensure_ascii=False, indent=4)

with open('dev.json', 'w', encoding='utf-8') as f:
    json.dump(dev_res, f, ensure_ascii=False, indent=4)
|
||||
186
MemoAI/qwen2-0.5b/app.py
Normal file
186
MemoAI/qwen2-0.5b/app.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import os
|
||||
import copy
|
||||
import random
|
||||
import threading
|
||||
import subprocess
|
||||
import gradio as gr
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
|
||||
|
||||
# Remove TensorFlow packages that conflict with this runtime.
# NOTE(review): running pip via os.system at import time is a heavy side
# effect; consider moving this into the deployment script.
os.system("pip uninstall -y tensorflow tensorflow-estimator tensorflow-io-gcs-filesystem")
os.environ["LANG"] = "C"
os.environ["LC_ALL"] = "C"

# Default system prompt shown in (and used by) the chat UI.
default_system = '你是一个微信聊天机器人'

from dashinfer.helper import EngineHelper, ConfigManager

# Serializes request logging across concurrent Gradio handlers.
log_lock = threading.Lock()

# DashInfer engine configuration (model space/name, devices, generation).
config_file = "di_config.json"
config = ConfigManager.get_config_from_json(config_file)
|
||||
def download_model(model_id, revision, source="modelscope"):
    """Fetch a model snapshot from ModelScope or the HuggingFace Hub.

    :param model_id: repository id of the model.
    :param revision: revision to download (used by ModelScope only).
    :param source: "modelscope" or "huggingface".
    :return: local directory the snapshot was saved to.
    :raises ValueError: for any other source.
    """
    print(f"Downloading model {model_id} (revision: {revision}) from {source}")
    if source == "modelscope":
        from modelscope import snapshot_download as fetch
        model_dir = fetch(model_id, revision=revision)
    elif source == "huggingface":
        from huggingface_hub import snapshot_download as fetch
        model_dir = fetch(repo_id=model_id)
    else:
        raise ValueError("Unknown source")

    print(f"Save model to path {model_dir}")

    return model_dir
|
||||
|
||||
# Locate the installed dashinfer package so the allspark daemon binary can
# be found, then pin NUMA settings to the configured device ids.
cmd = f"pip show dashinfer | grep 'Location' | cut -d ' ' -f 2"
package_location = subprocess.run(cmd,
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  shell=True,
                                  text=True)
package_location = package_location.stdout.strip()
os.environ["AS_DAEMON_PATH"] = package_location + "/dashinfer/allspark/bin"
os.environ["AS_NUMA_NUM"] = str(len(config["device_ids"]))
os.environ["AS_NUMA_OFFSET"] = str(config["device_ids"][0])
|
||||
|
||||
## download original model
## download model from modelscope
original_model = {
    "source": "modelscope",
    "model_id": config["model_space"] + config["model_name"],
    "revision": "master",
    "model_path": ""
}
original_model["model_path"] = download_model(original_model["model_id"],
                                              original_model["revision"],
                                              original_model["source"])

# Engine setup: tokenizer first, then a one-off conversion of the HF
# checkpoint into DashInfer format, then engine initialization.
engine_helper = EngineHelper(config)
engine_helper.verbose = True
engine_helper.init_tokenizer(original_model["model_path"])

## convert huggingface model to dashinfer model
## only one conversion is required
engine_helper.convert_model(original_model["model_path"])

engine_helper.init_engine()
# Upper bound on concurrent generation requests served by the UI.
engine_max_batch = engine_helper.engine_config["engine_max_batch"]
|
||||
|
||||
###################################################
|
||||
|
||||
# Chat history as (user, assistant) text pairs; messages as role/content dicts.
History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]


class Role:
    """String constants for the "role" field of a message dict."""
    USER = 'user'
    SYSTEM = 'system'
    BOT = 'bot'
    ASSISTANT = 'assistant'
    ATTACHMENT = 'attachment'
|
||||
|
||||
|
||||
def clear_session() -> "Tuple[str, History]":
    """Reset the chat UI: clear the input textbox and the chatbot history.

    (Annotation fixed: the original claimed `-> History` but the function
    returns a (textbox_value, history) pair.)
    """
    return '', []
|
||||
|
||||
|
||||
def modify_system_session(system: str) -> "Tuple[str, str, History]":
    """Apply a new system prompt and clear the chat history.

    Empty or missing input falls back to the module default. Returns the
    prompt twice (for the hidden state and the visible textbox) plus an
    empty history. (Annotation fixed: the original claimed `-> str`.)
    """
    if not system:
        system = default_system
    return system, system, []
|
||||
|
||||
|
||||
def history_to_messages(history: History, system: str) -> Messages:
    """Convert (user, assistant) pairs into a role-tagged message list,
    with the system prompt as the first entry."""
    messages = [{'role': Role.SYSTEM, 'content': system}]
    for user_turn, bot_turn in history:
        messages.append({'role': Role.USER, 'content': user_turn})
        messages.append({'role': Role.ASSISTANT, 'content': bot_turn})
    return messages
|
||||
|
||||
|
||||
def messages_to_history(messages: Messages) -> Tuple[str, History]:
    """Invert history_to_messages: split the leading system prompt from the
    alternating user/assistant turns."""
    assert messages[0]['role'] == Role.SYSTEM
    system = messages[0]['content']
    history = [[q['content'], r['content']]
               for q, r in zip(messages[1::2], messages[2::2])]
    return system, history
|
||||
|
||||
|
||||
def message_to_prompt(messages: Messages) -> str:
|
||||
prompt = ""
|
||||
for item in messages:
|
||||
im_start, im_end = "<|im_start|>", "<|im_end|>"
|
||||
prompt += f"\n{im_start}{item['role']}\n{item['content']}{im_end}"
|
||||
prompt += f"\n{im_start}assistant\n"
|
||||
return prompt
|
||||
|
||||
|
||||
def model_chat(query: Optional[str], history: Optional[History],
               system: str) -> Tuple[str, str, History]:
    """Generator driving one chat turn for the Gradio UI.

    Builds a ChatML prompt from the history plus the new query, streams the
    DashInfer engine's partial responses, and yields ('', history, system)
    after each partial so the UI updates live.
    """
    if query is None:
        query = ''
    if history is None:
        history = []

    messages = history_to_messages(history, system)
    messages.append({'role': Role.USER, 'content': query})
    prompt = message_to_prompt(messages)

    # Fresh generation config per request; random seed for varied replies.
    gen_cfg = copy.deepcopy(engine_helper.default_gen_cfg)
    gen_cfg["max_length"] = 1024
    gen_cfg["seed"] = random.randint(0, 10000)

    request_list = engine_helper.create_request([prompt], [gen_cfg])

    request = request_list[0]
    gen = engine_helper.process_one_request_stream(request)
    for response in gen:
        role = Role.ASSISTANT
        # Each partial response replaces the assistant turn being built.
        system, history = messages_to_history(messages + [{'role': role, 'content': response}])
        yield '', history, system

    # Log the finished request; the lock keeps concurrent handlers from
    # interleaving their output.
    json_str = engine_helper.convert_request_to_jsonstr(request)
    log_lock.acquire()
    try:
        print(f"{json_str}\n")
    finally:
        log_lock.release()
|
||||
|
||||
###################################################
|
||||
|
||||
# Build the Gradio chat UI and wire its event handlers.
with gr.Blocks() as demo:
    demo_title = "<center>微信的你</center>"
    gr.Markdown(demo_title)
    with gr.Row():
        with gr.Column(scale=3):
            # Editable system prompt shown to the user.
            system_input = gr.Textbox(value=default_system,
                                      lines=1,
                                      label='System')
        with gr.Column(scale=1):
            modify_system = gr.Button("🛠️ Set system prompt and clear history.", scale=2)
        # Hidden copy of the active system prompt actually fed to the model.
        system_state = gr.Textbox(value=default_system, visible=False)
    chatbot = gr.Chatbot(label=config["model_name"])
    textbox = gr.Textbox(lines=2, label='Input')

    with gr.Row():
        clear_history = gr.Button("🧹清除历史记录")
        # NOTE(review): "sumbit" is a typo for "submit" (local name only).
        sumbit = gr.Button("🚀和我聊天!")

    # Concurrency on every handler is capped at the engine's max batch size.
    sumbit.click(model_chat,
                 inputs=[textbox, chatbot, system_state],
                 outputs=[textbox, chatbot, system_input],
                 concurrency_limit=engine_max_batch)
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[textbox, chatbot],
                        concurrency_limit=engine_max_batch)
    modify_system.click(fn=modify_system_session,
                        inputs=[system_input],
                        outputs=[system_state, system_input, chatbot],
                        concurrency_limit=engine_max_batch)

# Serve on localhost only, with the public queue API disabled.
demo.queue(api_open=False).launch(height=800, share=False, server_name="127.0.0.1", server_port=7860)
|
||||
52
MemoAI/qwen2-0.5b/di_config.json
Normal file
52
MemoAI/qwen2-0.5b/di_config.json
Normal file
@@ -0,0 +1,52 @@
|
||||
{
|
||||
"model_space": "YOUR-NAME-SPACE",
|
||||
"model_name": "YOUR-MODEL-NAME",
|
||||
"model_type": "Qwen_v20",
|
||||
"model_path": "./dashinfer_models/",
|
||||
"data_type": "float32",
|
||||
"device_type": "CPU",
|
||||
"device_ids": [
|
||||
0
|
||||
],
|
||||
"multinode_mode": false,
|
||||
"engine_config": {
|
||||
"engine_max_length": 1024,
|
||||
"engine_max_batch": 2,
|
||||
"do_profiling": false,
|
||||
"num_threads": 0,
|
||||
"matmul_precision": "medium"
|
||||
},
|
||||
"generation_config": {
|
||||
"temperature": 0.7,
|
||||
"early_stopping": true,
|
||||
"top_k": 20,
|
||||
"top_p": 0.8,
|
||||
"repetition_penalty": 1.05,
|
||||
"presence_penalty": 0.0,
|
||||
"min_length": 0,
|
||||
"max_length": 8192,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"eos_token_id": 151643,
|
||||
"seed": 1234,
|
||||
"stop_words_ids": [
|
||||
[
|
||||
151643
|
||||
],
|
||||
[
|
||||
151644
|
||||
],
|
||||
[
|
||||
151645
|
||||
]
|
||||
]
|
||||
},
|
||||
"convert_config": {
|
||||
"do_dynamic_quantize_convert": false
|
||||
},
|
||||
"quantization_config": {
|
||||
"activation_type": "bfloat16",
|
||||
"weight_type": "uint8",
|
||||
"SubChannel": true,
|
||||
"GroupSize": 512
|
||||
}
|
||||
}
|
||||
1
MemoAI/qwen2-0.5b/requirements.txt
Normal file
1
MemoAI/qwen2-0.5b/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
dashinfer
|
||||
419
MemoAI/qwen2-0.5b/train.ipynb
Normal file
419
MemoAI/qwen2-0.5b/train.ipynb
Normal file
@@ -0,0 +1,419 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de53995b-32ed-4722-8cac-ba104c8efacb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 导入环境"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "52fac949-4150-4091-b0c3-2968ab5e385c",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import Dataset\n",
|
||||
"import pandas as pd\n",
|
||||
"from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e098d9eb",
|
||||
"metadata": {
|
||||
"ExecutionIndicator": {
|
||||
"show": true
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = pd.read_json('train.json')\n",
|
||||
"ds = Dataset.from_pandas(df)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8ac92d42-efae-49b1-a00e-ccaa75b98938",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ds[:3]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "380d9f69-9e98-4d2d-b044-1d608a057b0b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 下载模型"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "312d6439-1932-44a3-b592-9adbdb7ab702",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from modelscope import snapshot_download\n",
|
||||
"model_dir = snapshot_download('qwen/Qwen2-0.5B-Instruct', cache_dir='qwen2-0.5b/')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51d05e5d-d14e-4f03-92be-9a9677d41918",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 处理数据集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "74ee5a67-2e55-4974-b90e-cbf492de500a",
|
||||
"metadata": {
|
||||
"ExecutionIndicator": {
|
||||
"show": true
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer = AutoTokenizer.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct/', use_fast=False, trust_remote_code=True)\n",
|
||||
"tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2503a5fa-9621-4495-9035-8e7ef6525691",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_func(example):\n",
|
||||
" MAX_LENGTH = 384 # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性\n",
|
||||
" input_ids, attention_mask, labels = [], [], []\n",
|
||||
" instruction = tokenizer(f\"<|im_start|>system\\n现在你需要扮演我,和我的微信好友快乐聊天!<|im_end|>\\n<|im_start|>user\\n{example['instruction'] + example['input']}<|im_end|>\\n<|im_start|>assistant\\n\", add_special_tokens=False)\n",
|
||||
" response = tokenizer(f\"{example['output']}\", add_special_tokens=False)\n",
|
||||
" input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
|
||||
" attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1] # 因为eos token咱们也是要关注的所以 补充为1\n",
|
||||
" labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id] \n",
|
||||
" if len(input_ids) > MAX_LENGTH: # 做一个截断\n",
|
||||
" input_ids = input_ids[:MAX_LENGTH]\n",
|
||||
" attention_mask = attention_mask[:MAX_LENGTH]\n",
|
||||
" labels = labels[:MAX_LENGTH]\n",
|
||||
" return {\n",
|
||||
" \"input_ids\": input_ids,\n",
|
||||
" \"attention_mask\": attention_mask,\n",
|
||||
" \"labels\": labels\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "84f870d6-73a9-4b0f-8abf-687b32224ad8",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
|
||||
"tokenized_id"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1f7e15a0-4d9a-4935-9861-00cc472654b1",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer.decode(tokenized_id[0]['input_ids'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97f16f66-324a-454f-8cc3-ef23b100ecff",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1][\"labels\"])))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "424823a8-ed0d-4309-83c8-3f6b1cdf274c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 创建模型"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "170764e5-d899-4ef4-8c53-36f6dec0d198",
|
||||
"metadata": {
|
||||
"ExecutionIndicator": {
|
||||
"show": true
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct', device_map=\"auto\",torch_dtype=torch.bfloat16)\n",
|
||||
"model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2323eac7-37d5-4288-8bc5-79fac7113402",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.enable_input_require_grads()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f808b05c-f2cb-48cf-a80d-0c42be6051c7",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.dtype"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "13d71257-3c1c-4303-8ff8-af161ebc2cf1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# lora "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2d304ae2-ab60-4080-a80d-19cac2e3ade3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from peft import LoraConfig, TaskType, get_peft_model\n",
|
||||
"\n",
|
||||
"config = LoraConfig(\n",
|
||||
" task_type=TaskType.CAUSAL_LM, \n",
|
||||
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
||||
" inference_mode=False, # 训练模式\n",
|
||||
" r=8, # Lora 秩\n",
|
||||
" lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理\n",
|
||||
" lora_dropout=0.1# Dropout 比例\n",
|
||||
")\n",
|
||||
"config"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2c2489c5-eaab-4e1f-b06a-c3f914b4bf8e",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = get_peft_model(model, config)\n",
|
||||
"config"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ebf5482b-fab9-4eb3-ad88-c116def4be12",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.print_trainable_parameters()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ca055683-837f-4865-9c57-9164ba60c00f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 配置训练参数"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7e76bbff-15fd-4995-a61d-8364dc5e9ea0",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = TrainingArguments(\n",
|
||||
" output_dir=\"./output/\",\n",
|
||||
" per_device_train_batch_size=4,\n",
|
||||
" gradient_accumulation_steps=4,\n",
|
||||
" logging_steps=10,\n",
|
||||
" num_train_epochs=3,\n",
|
||||
" learning_rate=1e-4,\n",
|
||||
" gradient_checkpointing=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f142cb9c-ad99-48e6-ba86-6df198f9ed96",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer = Trainer(\n",
|
||||
" model=model,\n",
|
||||
" args=args,\n",
|
||||
" train_dataset=tokenized_id,\n",
|
||||
" data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aec9bc36-b297-45af-99e1-d4c4d82be081",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8abb2327-458e-4e96-ac98-2141b5b97c8e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 合并加载模型,这里的路径可能有点不太一样,lora_path填写为Output的最后的checkpoint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bd2a415a-a9ad-49ea-877f-243558a83bfc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
||||
"import torch\n",
|
||||
"from peft import PeftModel\n",
|
||||
"\n",
|
||||
"mode_path = './qwen2-0.5b/qwen/Qwen2-0___5B-Instruct'\n",
|
||||
"lora_path = './output/checkpoint-10' #修改这里\n",
|
||||
"# 加载tokenizer\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)\n",
|
||||
"\n",
|
||||
"# 加载模型\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(mode_path, device_map=\"auto\",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()\n",
|
||||
"\n",
|
||||
"# 加载lora权重\n",
|
||||
"model = PeftModel.from_pretrained(model, model_id=lora_path)\n",
|
||||
"\n",
|
||||
"prompt = \"在干啥呢?\"\n",
|
||||
"inputs = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": \"现在你需要扮演我,和我的微信好友快乐聊天!\"},{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" tokenize=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" return_dict=True\n",
|
||||
" ).to('cuda')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"gen_kwargs = {\"max_length\": 2500, \"do_sample\": True, \"top_k\": 1}\n",
|
||||
"with torch.no_grad():\n",
|
||||
" outputs = model.generate(**inputs, **gen_kwargs)\n",
|
||||
" outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
|
||||
" print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
|
||||
"\n",
|
||||
"# 保存合并后的模型和tokenizer\n",
|
||||
"save_directory = './model_merge'\n",
|
||||
"\n",
|
||||
"# 保存模型\n",
|
||||
"\n",
|
||||
"model.save_pretrained(save_directory)\n",
|
||||
"\n",
|
||||
"# 保存tokenizer\n",
|
||||
"tokenizer.save_pretrained(save_directory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b67e5e0a-2566-4483-9bce-92b5be8b4b34",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 然后把模型上传到modelscope开始下一步"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dafe4f24-af5c-407e-abbc-eefd9d44cb15",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
199
MemoAI/qwen2-0.5b/train.md
Normal file
199
MemoAI/qwen2-0.5b/train.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# Qwen2-0.5B-Instruct 微信AI 微调
|
||||
|
||||
这个教程给大家提供一个 [notebook](./train.ipynb) 文件,来让大家更好的学习。
|
||||
|
||||
## 模型下载
|
||||
|
||||
使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from modelscope import snapshot_download, AutoModel, AutoTokenizer
|
||||
import os
|
||||
model_dir = snapshot_download('qwen/Qwen2-0.5B-Instruct', cache_dir='/root/autodl-tmp', revision='master')
|
||||
```
|
||||
|
||||
## 环境配置
|
||||
|
||||
在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:
|
||||
|
||||
```bash
|
||||
python -m pip install --upgrade pip
|
||||
# 更换 pypi 源加速库的安装
|
||||
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
pip install modelscope==1.9.5
|
||||
pip install "transformers>=4.39.0"
|
||||
pip install streamlit==1.24.0
|
||||
pip install sentencepiece==0.1.99
|
||||
pip install accelerate==0.27
|
||||
pip install transformers_stream_generator==0.0.4
|
||||
pip install datasets==2.18.0
|
||||
pip install peft==0.10.0
|
||||
|
||||
```
|
||||
|
||||
LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction":"以下是你的好友在和你聊天,你需要和他聊天",
|
||||
"input":"吃了吗?",
|
||||
"output":"还在食堂"
|
||||
}
|
||||
```
|
||||
|
||||
其中,`instruction` 是用户指令,告知模型其需要完成的任务;`input` 是用户输入,是完成用户指令所必须的输入内容;`output` 是模型应该给出的输出。
|
||||
|
||||
|
||||
|
||||
|
||||
## 数据格式化
|
||||
|
||||
`Lora` 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 `Pytorch` 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 `labels`,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:
|
||||
|
||||
```python
|
||||
def process_func(example):
|
||||
MAX_LENGTH = 384 # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
|
||||
input_ids, attention_mask, labels = [], [], []
|
||||
instruction = tokenizer(f"<|im_start|>system\n现在你要扮演皇帝身边的女人--甄嬛<|im_end|>\n<|im_start|>user\n{example['instruction'] + example['input']}<|im_end|>\n<|im_start|>assistant\n", add_special_tokens=False) # add_special_tokens 不在开头加 special_tokens
|
||||
response = tokenizer(f"{example['output']}", add_special_tokens=False)
|
||||
input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
|
||||
attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1] # 因为eos token咱们也是要关注的所以 补充为1
|
||||
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
|
||||
if len(input_ids) > MAX_LENGTH: # 做一个截断
|
||||
input_ids = input_ids[:MAX_LENGTH]
|
||||
attention_mask = attention_mask[:MAX_LENGTH]
|
||||
labels = labels[:MAX_LENGTH]
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels
|
||||
}
|
||||
```
|
||||
|
||||
`Qwen2` 采用的`Prompt Template`格式如下:
|
||||
|
||||
```text
|
||||
<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
你是谁?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
我是一个有用的助手。<|im_end|>
|
||||
```
|
||||
|
||||
## 加载tokenizer和半精度模型
|
||||
|
||||
模型以半精度形式加载,如果你的显卡比较新的话,可以用`torch.bfloat16`形式加载。对于自定义的模型一定要指定`trust_remote_code`参数为`True`。
|
||||
|
||||
```python
|
||||
tokenizer = AutoTokenizer.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct/', use_fast=False, trust_remote_code=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct/', device_map="auto",torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## 定义LoraConfig
|
||||
|
||||
`LoraConfig`这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。
|
||||
|
||||
- `task_type`:模型类型
|
||||
- `target_modules`:需要训练的模型层的名字,主要就是`attention`部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
|
||||
- `r`:`lora`的秩,具体可以看`Lora`原理
|
||||
- `lora_alpha`:`Lora alaph`,具体作用参见 `Lora` 原理
|
||||
|
||||
`Lora`的缩放是啥嘞?当然不是`r`(秩),这个缩放就是`lora_alpha/r`, 在这个`LoraConfig`中缩放就是4倍。
|
||||
|
||||
```python
|
||||
config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
||||
inference_mode=False, # 训练模式
|
||||
r=8, # Lora 秩
|
||||
lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
|
||||
lora_dropout=0.1# Dropout 比例
|
||||
)
|
||||
```
|
||||
|
||||
## 自定义 TrainingArguments 参数
|
||||
|
||||
`TrainingArguments`这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。
|
||||
|
||||
- `output_dir`:模型的输出路径
|
||||
- `per_device_train_batch_size`:顾名思义 `batch_size`
|
||||
- `gradient_accumulation_steps`: 梯度累加,如果你的显存比较小,那可以把 `batch_size` 设置小一点,梯度累加增大一些。
|
||||
- `logging_steps`:多少步,输出一次`log`
|
||||
- `num_train_epochs`:顾名思义 `epoch`
|
||||
- `gradient_checkpointing`:梯度检查,这个一旦开启,模型就必须执行`model.enable_input_require_grads()`,这个原理大家可以自行探索,这里就不细说了。
|
||||
|
||||
```python
|
||||
args = TrainingArguments(
|
||||
output_dir="./output",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
logging_steps=10,
|
||||
num_train_epochs=3,
|
||||
save_steps=100,
|
||||
learning_rate=1e-4,
|
||||
save_on_each_node=True,
|
||||
gradient_checkpointing=True
|
||||
)
|
||||
```
|
||||
|
||||
## 使用 Trainer 训练
|
||||
|
||||
```python
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=tokenized_id,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## 加载 lora 权重推理
|
||||
|
||||
训练好了之后可以使用如下方式加载`lora`权重进行推理:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
|
||||
mode_path = './qwen2-0.5b/qwen/Qwen2-0___5B-Instruct/'
|
||||
lora_path = 'lora_path'
|
||||
|
||||
# 加载tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(mode_path)
|
||||
|
||||
# 加载模型
|
||||
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16)
|
||||
|
||||
# 加载lora权重
|
||||
model = PeftModel.from_pretrained(model, model_id=lora_path, config=config)
|
||||
|
||||
prompt = "你是谁?"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to('cuda')
|
||||
|
||||
generated_ids = model.generate(
|
||||
model_inputs.input_ids,
|
||||
max_new_tokens=512
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
450
MemoAI/readme.md
Normal file
450
MemoAI/readme.md
Normal file
@@ -0,0 +1,450 @@
|
||||
# 大模型训练指南
|
||||
|
||||
## 一、导出聊天记录
|
||||
|
||||
导出json格式的聊天记录。
|
||||
|
||||

|
||||
|
||||
如果你想合并多个联系人的数据,可以直接运行下面的代码合并
|
||||
|
||||
```python
|
||||
import json
|
||||
import os
|
||||
|
||||
data_dir = r'E:\Project\Python\MemoTrace\data\聊天记录'
|
||||
|
||||
dev_res = []
|
||||
train_res = []
|
||||
|
||||
for filepath, dirnames, filenames in os.walk(data_dir):
|
||||
for filename in filenames:
|
||||
if filename.endswith('.json'):
|
||||
print(filename, filepath)
|
||||
filepath_ = os.path.join(filepath, filename)
|
||||
with open(filepath_, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
if data:
|
||||
if filename.endswith('train.json'):
|
||||
train_res += data
|
||||
else:
|
||||
dev_res += data
|
||||
|
||||
with open('train.json', 'w', encoding='utf-8') as f:
|
||||
json.dump(train_res, f, ensure_ascii=False, indent=4)
|
||||
|
||||
with open('dev.json', 'w', encoding='utf-8') as f:
|
||||
json.dump(dev_res, f, ensure_ascii=False, indent=4)
|
||||
|
||||
```
|
||||
|
||||
你现在应该有两个文件,dev.json(验证集)和train.json(训练集)
|
||||
|
||||
## 二、下载ChatGLM3-6B模型
|
||||
|
||||
下载地址:[https://github.com/THUDM/ChatGLM3](https://github.com/THUDM/ChatGLM3)
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 环境安装
|
||||
|
||||
首先需要下载本仓库:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/THUDM/ChatGLM3
|
||||
cd ChatGLM3
|
||||
```
|
||||
|
||||
然后使用 pip 安装依赖:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
+ 为了保证 `torch` 的版本正确,请严格按照 [官方文档](https://pytorch.org/get-started/locally/) 的说明安装。
|
||||
+ **如果遇到问题,请参照ChatGLM3项目的解决方案,不要在本项目中提问**
|
||||
|
||||
## 三、ChatGLM3-6B 微调
|
||||
|
||||
本目录提供 ChatGLM3-6B 模型的微调示例,包括全量微调和 P-Tuning v2。格式上,提供多轮对话微调样例和输入输出格式微调样例。
|
||||
|
||||
如果将模型下载到了本地,本文和代码中的 `THUDM/chatglm3-6b` 字段均应替换为相应地址以从本地加载模型。
|
||||
|
||||
运行示例需要 `python>=3.10`,除基础的 `torch` 依赖外,示例代码运行还需要依赖。
|
||||
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 测试硬件标准
|
||||
|
||||
我们仅提供了单机多卡/多机多卡的运行示例,因此您需要至少一台具有多个 GPU 的机器。本仓库中的**默认配置文件**中,我们记录了显存的占用情况:
|
||||
|
||||
+ SFT 全量微调: 4张显卡平均分配,每张显卡占用 `48346MiB` 显存。
|
||||
+ P-TuningV2 微调: 1张显卡,占用 `18426MiB` 显存。
|
||||
+ LORA 微调: 1张显卡,占用 `14082MiB` 显存。
|
||||
|
||||
> 请注意,该结果仅供参考,对于不同的参数,显存占用可能会有所不同。请结合你的硬件情况进行调整。
|
||||
|
||||
> 请注意,我们仅仅使用英伟达 Hopper(代表显卡:H100) 和 Ampère(代表显卡:A100) 架构和系列显卡做过测试。如果您使用其他架构的显卡,可能会出现
|
||||
> 1. 未知的训练问题 / 显存占用与上述有误差。
|
||||
> 2. 架构过低而不支持某些特性。
|
||||
> 3. 推理效果问题。
|
||||
> 以上三种情况为社区曾经遇到过的问题,虽然概率极低,如果您遇到了以上问题,可以尝试在社区中解决。
|
||||
|
||||
## 多轮对话格式
|
||||
|
||||
多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同 `loss_mask` 从而在一遍计算中为多轮回复计算 `loss`。
|
||||
|
||||
对于数据文件,样例采用如下格式
|
||||
|
||||
如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "<system prompt text>"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<user prompt text>"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<assistant response text>"
|
||||
},
|
||||
// ... Muti Turn
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<user prompt text>"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<assistant response text>"
|
||||
}
|
||||
]
|
||||
}
|
||||
// ...
|
||||
]
|
||||
```
|
||||
|
||||
**请注意,这种方法在微调的step较多的情况下会影响到模型的工具调用功能**
|
||||
|
||||
- `system` 角色为可选角色,但若存在 `system` 角色,其必须出现在 `user`
|
||||
角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 `system` 角色。
|
||||
|
||||
## 数据集格式示例
|
||||
|
||||
这里以 AdvertiseGen 数据集为例,
|
||||
您可以从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing)
|
||||
或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载 AdvertiseGen 数据集。
|
||||
将解压后的 AdvertiseGen 目录放到 `data` 目录下并自行转换为如下格式数据集。
|
||||
|
||||
> 请注意,现在的微调代码中加入了验证集,因此,对于一组完整的微调数据集,必须包含训练数据集和验证数据集,测试数据集可以不填写。或者直接用验证数据集代替。
|
||||
|
||||
```
|
||||
{"conversations": [{"role": "user", "content": "类型#裙*裙长#半身裙"}, {"role": "assistant", "content": "这款百搭时尚的仙女半身裙,整体设计非常的飘逸随性,穿上之后每个女孩子都能瞬间变成小仙女啦。料子非常的轻盈,透气性也很好,穿到夏天也很舒适。"}]}
|
||||
```
|
||||
|
||||
## 配置文件
|
||||
|
||||
微调配置文件位于 `config` 目录下,包括以下文件:
|
||||
|
||||
1. `ds_zereo_2 / ds_zereo_3.json`: deepspeed 配置文件。
|
||||
2. `lora.yaml / ptuning.yaml / sft.yaml`: 模型不同方式的配置文件,包括模型参数、优化器参数、训练参数等。 部分重要参数解释如下:
|
||||
+ data_config 部分
|
||||
+ train_file: 训练数据集的文件路径。
|
||||
+ val_file: 验证数据集的文件路径。
|
||||
+ test_file: 测试数据集的文件路径。
|
||||
+ num_proc: 在加载数据时使用的进程数量。
|
||||
+ max_input_length: 输入序列的最大长度。
|
||||
+ max_output_length: 输出序列的最大长度。
|
||||
+ training_args 部分
|
||||
+ output_dir: 用于保存模型和其他输出的目录。
|
||||
+ max_steps: 训练的最大步数。
|
||||
+ per_device_train_batch_size: 每个设备(如 GPU)的训练批次大小。
|
||||
+ dataloader_num_workers: 加载数据时使用的工作线程数量。
|
||||
+ remove_unused_columns: 是否移除数据中未使用的列。
|
||||
+ save_strategy: 模型保存策略(例如,每隔多少步保存一次)。
|
||||
+ save_steps: 每隔多少步保存一次模型。
|
||||
+ log_level: 日志级别(如 info)。
|
||||
+ logging_strategy: 日志记录策略。
|
||||
+ logging_steps: 每隔多少步记录一次日志。
|
||||
+ per_device_eval_batch_size: 每个设备的评估批次大小。
|
||||
+ evaluation_strategy: 评估策略(例如,每隔多少步进行一次评估)。
|
||||
+ eval_steps: 每隔多少步进行一次评估。
|
||||
+ predict_with_generate: 是否使用生成模式进行预测。
|
||||
+ generation_config 部分
|
||||
+ max_new_tokens: 生成的最大新 token 数量。
|
||||
+ peft_config 部分
|
||||
+ peft_type: 使用的参数有效调整类型(如 LORA)。
|
||||
+ task_type: 任务类型,这里是因果语言模型(CAUSAL_LM)。
|
||||
+ Lora 参数:
|
||||
+ r: LoRA 的秩。
|
||||
+ lora_alpha: LoRA 的缩放因子。
|
||||
+ lora_dropout: 在 LoRA 层使用的 dropout 概率
|
||||
+ P-TuningV2 参数:
|
||||
+ num_virtual_tokens: 虚拟 token 的数量。
|
||||
|
||||
## 开始微调
|
||||
|
||||
通过以下代码执行 **单机多卡/多机多卡** 运行,这是使用 `deepspeed` 作为加速方案的,您需要安装 `deepspeed`。
|
||||
|
||||
```angular2html
|
||||
cd finetune_demo
|
||||
OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml configs/ds_zero_2.json
|
||||
```
|
||||
|
||||
通过以下代码执行 **单机单卡** 运行。
|
||||
|
||||
```angular2html
|
||||
cd finetune_demo
|
||||
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml
|
||||
```
|
||||
|
||||
## 从保存点进行微调
|
||||
|
||||
如果按照上述方式进行训练,每次微调都会从头开始,如果你想从训练一半的模型开始微调,你可以加入第四个参数,这个参数有两种传入方式:
|
||||
|
||||
1. `yes`, 自动从最后一个保存的 Checkpoint开始训练
|
||||
2. `XX`, 断点号数字 例 `600` 则从序号600 Checkpoint开始训练
|
||||
|
||||
例如,这就是一个从最后一个保存点继续微调的示例代码
|
||||
|
||||
```angular2html
|
||||
cd finetune_demo
|
||||
python finetune_hf.py data/AdvertiseGen/ THUDM/chatglm3-6b configs/lora.yaml yes
|
||||
```
|
||||
|
||||
## 使用微调后的模型
|
||||
|
||||
### 在 inference_hf.py 中验证微调后的模型
|
||||
|
||||
您可以在 `finetune_demo/inference_hf.py` 中使用我们的微调后的模型,仅需要一行代码就能简单的进行测试。
|
||||
|
||||
```angular2html
|
||||
python inference_hf.py your_finetune_path --prompt your prompt
|
||||
```
|
||||
|
||||
这样,得到的回答就微调后的回答了。
|
||||
|
||||
### 在本仓库的其他 demo 或者外部仓库使用微调后的模型
|
||||
|
||||
您可以在任何一个 demo 内使用我们的 `lora` 和 全参微调的模型。这需要你自己按照以下教程进行修改代码。
|
||||
|
||||
1. 使用`finetune_demo/inference_hf.py`中读入模型的方式替换 demo 中读入模型的方式。
|
||||
|
||||
> 请注意,对于 LORA 和 P-TuningV2 我们没有合并训练后的模型,而是在`adapter_config.json`
|
||||
> 中记录了微调型的路径,如果你的原始模型位置发生更改,则你应该修改`adapter_config.json`中`base_model_name_or_path`的路径。
|
||||
|
||||
```python
|
||||
def load_model_and_tokenizer(
|
||||
model_dir: Union[str, Path], trust_remote_code: bool = True
|
||||
) -> tuple[ModelType, TokenizerType]:
|
||||
model_dir = _resolve_path(model_dir)
|
||||
if (model_dir / 'adapter_config.json').exists():
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
||||
)
|
||||
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
|
||||
)
|
||||
tokenizer_dir = model_dir
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_dir, trust_remote_code=trust_remote_code
|
||||
)
|
||||
return model, tokenizer
|
||||
```
|
||||
|
||||
2. 读取微调的模型,请注意,你应该使用微调模型的位置,例如,若你的模型位置为`/path/to/finetune_adapter_model`
|
||||
,原始模型地址为`path/to/base_model`,则你应该使用`/path/to/finetune_adapter_model`作为`model_dir`。
|
||||
3. 完成上述操作后,就能正常使用微调的模型了,其他的调用方式没有变化。
|
||||
|
||||
### 提示
|
||||
|
||||
1. 微调代码在开始训练前,会先打印首条训练数据的预处理信息(默认已经注释,可以解除注释),显示为
|
||||
|
||||
```log
|
||||
Sanity
|
||||
Check >> >> >> >> >> >> >
|
||||
'[gMASK]': 64790 -> -100
|
||||
'sop': 64792 -> -100
|
||||
'<|system|>': 64794 -> -100
|
||||
'': 30910 -> -100
|
||||
'\n': 13 -> -100
|
||||
'Answer': 20115 -> -100
|
||||
'the': 267 -> -100
|
||||
'following': 1762 -> -100
|
||||
...
|
||||
'know': 683 -> -100
|
||||
'the': 267 -> -100
|
||||
'response': 3010 -> -100
|
||||
'details': 3296 -> -100
|
||||
'.': 30930 -> -100
|
||||
'<|assistant|>': 64796 -> -100
|
||||
'': 30910 -> 30910
|
||||
'\n': 13 -> 13
|
||||
'I': 307 -> 307
|
||||
'need': 720 -> 720
|
||||
'to': 289 -> 289
|
||||
'use': 792 -> 792
|
||||
...
|
||||
<< << << << << << < Sanity
|
||||
Check
|
||||
```
|
||||
|
||||
字样,每行依次表示一个 detokenized string, token_id 和 target_id。其中,`target_id`为`token_id`在模型词表中的索引,`-100`表示该
|
||||
token 不参与 `loss` 计算。
|
||||
|
||||
2. `_prepare_model_for_training` 的作用是遍历模型的所有可训练参数,并确保它们的数据类型为`torch.float32`。
|
||||
这在某些情况下是必要的,因为混合精度训练或其他操作可能会更改模型参数的数据类型。该代码默认打开,可以注释,但是如果使用
|
||||
`half` 格式训练出现问题,可以切换回这个代码,显存可能增加。
|
||||
3. 在我们的[Huggingface模型代码](https://huggingface.co/THUDM/chatglm3-6b/blob/main/modeling_chatglm.py)中,有以下内容:
|
||||
```python
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_ret = torch.utils.checkpoint.checkpoint(
|
||||
layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
rotary_pos_emb,
|
||||
kv_caches[index],
|
||||
use_cache,
|
||||
use_reentrant=False
|
||||
)
|
||||
```
|
||||
这可能导致训练的时候显存增加,因此,如果您的显存不足,可以尝试将``` use_reentrant``` 修改为`True`。
|
||||
4. 微调后的模型可以使用任何支持 `peft` 载入的模型加速框架,在这里,我们没有提供demo。
|
||||
5. 本仓库的微调数据集格式与 API 微调数据集格式有一定区别
|
||||
+ ZhipuAI API 微调数据集中的 `messages` 字段在本仓库为 `conversation` 字段。
|
||||
+ ZhipuAI API 中的微调文件为 `jsonl`, 在本仓库,需要简单的将文件名改为 `json`。
|
||||
|
||||
> 以上内容来自ChatGLM3项目
|
||||
|
||||
## 微调示例
|
||||
|
||||
配置文件
|
||||
|
||||
config/lora.yaml
|
||||
|
||||
```yaml
|
||||
data_config:
|
||||
train_file: train.json
|
||||
val_file: dev.json
|
||||
test_file: dev.json
|
||||
num_proc: 10
|
||||
max_input_length: 512
|
||||
max_output_length: 128
|
||||
training_args:
|
||||
# see `transformers.Seq2SeqTrainingArguments`
|
||||
output_dir: ./output03-24
|
||||
max_steps: 100000
|
||||
# settings for data loading
|
||||
per_device_train_batch_size: 4
|
||||
dataloader_num_workers: 10
|
||||
remove_unused_columns: false
|
||||
# settings for saving checkpoints
|
||||
save_strategy: steps
|
||||
save_steps: 2000
|
||||
# settings for logging
|
||||
log_level: info
|
||||
logging_strategy: steps
|
||||
logging_steps: 10
|
||||
# settings for evaluation
|
||||
per_device_eval_batch_size: 4
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 5200
|
||||
# settings for optimizer
|
||||
# adam_epsilon: 1e-6
|
||||
# uncomment the following line to detect nan or inf values
|
||||
# debug: underflow_overflow
|
||||
predict_with_generate: yes
|
||||
# see `transformers.GenerationConfig`
|
||||
generation_config:
|
||||
max_new_tokens: 256
|
||||
# set your absolute deepspeed path here
|
||||
#deepspeed: ds_zero_2.json
|
||||
# set to true if train with cpu.
|
||||
use_cpu: false
|
||||
peft_config:
|
||||
peft_type: LORA
|
||||
task_type: CAUSAL_LM
|
||||
r: 8
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.1
|
||||
```
|
||||
|
||||
硬件配置:4090 24G、64G内存、CPU 14700KF 20核28线程
|
||||
|
||||
你需要根据你的硬件配置修改上述参数,各个参数含义上面有说
|
||||
|
||||
微调命令(需要指定数据集路径和ChatGLM3基础大模型的路径)
|
||||
|
||||
```shell
|
||||
python finetune_hf.py data/ E:\\Project\\Python\\Langchain-Chatchat\\chatglm3-6b configs/lora.yaml yes
|
||||
```
|
||||
|
||||
## 部署
|
||||
|
||||
在api_server.py修改微调保存路径
|
||||
```python
|
||||
model, tokenizer = load_model_and_tokenizer(
|
||||
r'E:\Project\Python\ChatGLM3\finetune_demo\output03-24\checkpoint-224000'
|
||||
)
|
||||
```
|
||||
|
||||
直接运行即可
|
||||
|
||||
```shell
|
||||
python api_server.py
|
||||
```
|
||||
|
||||
调用示例(你可以在任意一个支持ChatGPT的应用中使用它):
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
base_url = "http://127.0.0.1:8002/v1/"
|
||||
client = OpenAI(api_key="EMPTY", base_url=base_url)
|
||||
|
||||
def simple_chat(use_stream=True):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好啊"
|
||||
}
|
||||
]
|
||||
response = client.chat.completions.create(
|
||||
model="chatglm3-6b",
|
||||
messages=messages,
|
||||
stream=use_stream,
|
||||
max_tokens=256,
|
||||
temperature=0.8,
|
||||
presence_penalty=1.1,
|
||||
top_p=0.8)
|
||||
if response:
|
||||
if use_stream:
|
||||
for chunk in response:
|
||||
print(chunk.choices[0].delta.content, end='')
|
||||
else:
|
||||
content = response.choices[0].message.content
|
||||
print(content)
|
||||
else:
|
||||
print("Error:", response.status_code)
|
||||
|
||||
if __name__ == "__main__":
|
||||
simple_chat(use_stream=True)
|
||||
```
|
||||
|
||||
## 体验地址
|
||||
|
||||
[https://chat.memotrace.cn/](https://chat.memotrace.cn/)
|
||||
|
||||

|
||||
|
||||

|
||||
Reference in New Issue
Block a user