import ray
from ray.util import ActorPool

@ray.remote(num_gpus=1)
def start_torchrun(rank, nproc_per_node=1, num_nodes=1, master_addr="127.0.0.1", master_port=29500):
    """Launch one ``torchrun`` node of the LLaMA pretraining job and wait for it.

    The task runs ``/workspace/llama_pretrain/llama_pretrain.py`` through
    ``torchrun`` and redirects the child's stdout/stderr to
    ``/workspace/logs/node_<rank>/out.log`` and ``err.log``.

    Args:
        rank: Value passed to ``--node_rank``; also used in the log directory name.
        nproc_per_node: Value passed to ``--nproc_per_node``.
        num_nodes: Value passed to ``--nnodes``.
        master_addr: Value passed to ``--master_addr``.
        master_port: Value passed to ``--master_port``.

    Returns:
        ``rank``, once ``torchrun`` has exited with status 0.

    Raises:
        RuntimeError: If ``torchrun`` exits with a non-zero status.
    """
    # NOTE(review): the decorator reserves a single GPU per task; with
    # nproc_per_node > 1 the workers would presumably share it -- confirm.
    import os
    import socket
    import subprocess

    hostname = socket.gethostname()
    # The IP address is only used in the log line below, so a hostname that
    # does not resolve must not prevent the training run from starting.
    try:
        ip_address = socket.gethostbyname(hostname)
    except OSError:
        ip_address = "unknown"

    root_dir = '/workspace'
    print(f"Node {rank} is running on {hostname}, ip address: {ip_address}")

    cmd = [
        "torchrun",
        "--nproc_per_node", str(nproc_per_node),
        "--nnodes", str(num_nodes),
        "--node_rank", str(rank),
        "--master_addr", master_addr,
        "--master_port", str(master_port),
        f"{root_dir}/llama_pretrain/llama_pretrain.py",
        "--model_type", "llama",
        "--config_overrides", "num_attention_heads=32,num_hidden_layers=1,num_key_value_heads=1",
        "--tokenizer_name", f"{root_dir}/llama_tokenizer",
        "--train_file", f"{root_dir}/llama_pretrain/data/pretrain_data.txt",
        "--per_device_train_batch_size", "1",
        "--per_device_eval_batch_size", "1",
        "--bf16", "True",
        "--overwrite_output_dir",
        "--do_train",
        "--do_eval",
        "--logging_strategy", "steps",
        "--logging_steps", "10",
        "--output_dir", f"{root_dir}/llama_pretrain/tmp",
        "--save_strategy", "no",
        "--num_train_epochs", "1",
    ]

    print(f"Running command: {cmd}")

    # Both log files live in the same per-node directory, so create it once.
    log_dir = f"{root_dir}/logs/node_{rank}"
    out_log_path = f"{log_dir}/out.log"
    err_log_path = f"{log_dir}/err.log"
    os.makedirs(log_dir, exist_ok=True)

    # Run the child with stdout/stderr redirected to the per-node log files.
    with open(out_log_path, "w") as out_file, open(err_log_path, "w") as err_file:
        completed = subprocess.run(cmd, stdout=out_file, stderr=err_file)

    # Surface failures to the caller instead of reporting every run as done.
    if completed.returncode != 0:
        raise RuntimeError(
            f"torchrun on node {rank} exited with code {completed.returncode}; "
            f"see {err_log_path}"
        )

    return rank

if __name__ == '__main__':
    # Connect to an existing Ray cluster ("auto" address) rather than a new one.
    ray.init(address="auto")

    # Use the public helper instead of reaching into ray._private, which is
    # not a stable API.
    master_addr = ray.util.get_node_ip_address()
    master_port = 29500  # fixed rendezvous port
    num_nodes = 3  # number of torchrun nodes to launch
    nproc_per_node = 1  # processes per node
    print(f"Master Address (from Ray): {master_addr}")

    # NOTE(review): master_addr is the driver's IP, but nothing here pins the
    # rank-0 task to the driver's host; presumably the rendezvous only works
    # when Ray happens to schedule rank 0 there -- confirm.
    tasks = [
        start_torchrun.remote(rank, nproc_per_node, num_nodes, master_addr, master_port)
        for rank in range(num_nodes)
    ]

    # Block until every launcher task has returned, then report each rank.
    results = ray.get(tasks)
    for rank in results:
        print(f"Node {rank} complete successfully")

