# Launch the LLaMA SFT training script (llama_sft.py) through torchrun.
#
# Required variables (must be set and non-empty before this script runs):
#   nproc_per_node, nnodes, node_rank, master_addr, master_port
# Each one is forwarded to the torchrun option of the same name. Checking them
# up front gives a clear error instead of torchrun mis-parsing a shifted
# argument list when one of them is empty.
: "${nproc_per_node:?nproc_per_node must be set}"
: "${nnodes:?nnodes must be set}"
: "${node_rank:?node_rank must be set}"
: "${master_addr:?master_addr must be set}"
: "${master_port:?master_port must be set}"

# Every expansion is quoted so each value is passed as exactly one argument.
# If training fails, stop here with torchrun's exit status; otherwise the
# script would continue to later steps and report their status instead,
# hiding the training failure from whatever launched this script.
torchrun \
    --nproc_per_node "${nproc_per_node}" \
    --nnodes "${nnodes}" \
    --node_rank "${node_rank}" \
    --master_addr "${master_addr}" \
    --master_port "${master_port}" \
    /opt/aps/workdir/examples/llama_sft/llama_sft.py \
    --model_name_or_path model://6c939cbd-9f14-4ffe-9874-713920298d3d \
    --model_type llama \
    --tokenizer_name model://6c939cbd-9f14-4ffe-9874-713920298d3d \
    --input_dataset_path /opt/aps/workdir/examples/llama_sft/sft_data.json \
    --deepspeed "/opt/aps/workdir/examples/llama_sft/ds_config.json" \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --do_train \
    --do_eval \
    --logging_strategy steps \
    --logging_steps 1 \
    --output_dir "tmp" \
    --num_train_epochs 1 || exit "$?"

# ============= Alternative (uncomment to use): same launch, but passing --config_overrides instead of --model_name_or_path / --deepspeed =============
# torchrun  --nproc_per_node $nproc_per_node  --nnodes $nnodes  --node_rank $node_rank  --master_addr $master_addr  --master_port $master_port examples/llama_sft/llama_sft.py \
#     --model_type llama \
#     --tokenizer_name model://6c939cbd-9f14-4ffe-9874-713920298d3d \
#     --input_dataset_path examples/llama_sft/sft_data.json \
#     --overwrite_output_dir \
#     --per_device_train_batch_size 1 \
#     --per_device_eval_batch_size 1 \
#     --do_train \
#     --do_eval \
#     --logging_strategy steps \
#     --logging_steps 1 \
#     --output_dir "tmp" \
#     --num_train_epochs 1 \
#     --config_overrides "num_attention_heads=1,num_hidden_layers=1,num_key_value_heads=1"

# Post-training step, executed only on the node whose node_rank is 0:
# run merge_lora_weight.py with the base model URI and the "tmp" directory
# as --lora_dir. Other ranks skip this block entirely.
if [ "${node_rank}" -eq 0 ]; then
    # Log the rank so the output shows which node performed the merge.
    # Quoted printf avoids word-splitting/globbing of the expanded value.
    printf '%s\n' "${node_rank}"
    python /opt/aps/workdir/examples/llama_sft/merge_lora_weight.py \
        --model_name_or_path "model://6c939cbd-9f14-4ffe-9874-713920298d3d" \
        --lora_dir "tmp"
fi
