from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TrainerCallback, TrainingArguments, TrainerCallback, TrainerState, TrainerControl
import torch
import argparse
import json
from datacanvas.aps import dc
from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
)
import os
import shutil


# Set the WANDB_DISABLED environment variable for the rest of this process.
# NOTE(review): presumably this switches off the Weights & Biases reporting
# integration used by transformers -- confirm against the installed version.
os.environ["WANDB_DISABLED"] = "true"

# Command-line interface of the script.
parser = argparse.ArgumentParser(description="Run alaya model sft")

# Model identifier; further down it is passed to the model repository client
# to be resolved into a local directory.
parser.add_argument(
    "--model_name_or_path",
    type=str,
    help="model_name_or_path",
)

# Name of the adapter directory, which the script appends to the output model
# directory. It is required: when omitted, argparse would leave it as None and
# the later str() conversion would silently produce a path ending in "None".
parser.add_argument(
    "--lora_dir",
    type=str,
    required=True,
    help="lora_dir",
)

args = parser.parse_args()

# Resolve the model identifier from the command line to a path through the
# model repository client, and store the resolved value back on `args`.
from datacanvas.model import model_repo_client

model_path = model_repo_client.get_model_path(args.model_name_or_path)
model_dir = args.model_name_or_path = model_path

# Load the config, the base model (requested as bfloat16) and the tokenizer
# from the same resolved directory.
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

# The adapter is read from "<output model dir>/<lora_dir>".
peft_path = f"{dc.output.model_dir!s}/{args.lora_dir!s}"

# Attach the adapter to the base model, then merge it and drop the PEFT wrapper.
model = PeftModel.from_pretrained(base_model, peft_path).merge_and_unload()

# Write the tokenizer first, then the merged model, into the output model
# directory; safetensors serialization is explicitly turned off for the model.
tokenizer.save_pretrained(dc.output.model_dir)
model.save_pretrained(dc.output.model_dir, safe_serialization=False)
print('merged lora weight')

# Delete the LoRA model directory (currently disabled)
# try:
#     shutil.rmtree(peft_path)
#     print("文件夹已删除")
# except OSError as e:
#     print(f"Error: {e.strerror}")

# Rewrite the saved config.json in the output model directory without its
# '_name_or_path' entry.
path = f"{dc.output.model_dir!s}/config.json"

with open(path, 'r') as src:
    data = json.load(src)

# pop() with a default removes the key when present and is a no-op otherwise.
data.pop('_name_or_path', None)

with open(path, 'w') as dst:
    json.dump(data, dst)

# Report where the merged model was written and signal completion.
print("model_dir", dc.output.model_dir)
print('Done')
