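Create a distributed training task
Creates a distributed training task. This endpoint is asynchronous. A successful call returns a subscription result identifier, resultKey. Pass it to the Query Training Task Subscription Result endpoint and poll until you get the actual training task ID and the creation result.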
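Because creation is asynchronous, a client typically submits the task, keeps the resultKey, and polls until a final result appears. The sketch below shows only the polling loop.

The query endpoint's URL and response shape are not documented on this page. For that reason, `fetch_result` is a hypothetical callable you supply to wrap that endpoint. It is assumed to return None while creation is still pending and the final result otherwise. The interval and timeout defaults are arbitrary choices.

```python
import time
from typing import Any, Callable, Optional


def wait_for_result(
    fetch_result: Callable[[str], Optional[Any]],
    result_key: str,
    interval_s: float = 5.0,
    timeout_s: float = 600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Any:
    """Poll fetch_result(result_key) until it returns something other than None.

    fetch_result is a hypothetical, caller-supplied wrapper around the
    "query training task subscription result" endpoint (not defined here).
    """
    deadline = time.monotonic() + timeout_s
    while True:
        result = fetch_result(result_key)
        if result is not None:
            return result  # creation finished (successfully or not)
        if time.monotonic() >= deadline:
            raise TimeoutError(f"no result for {result_key} after {timeout_s}s")
        sleep(interval_s)  # still pending: wait before the next query
```

The `sleep` parameter is injectable only so the loop can be exercised without real delays.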
POST https://api.alayanew.com/v1/subscription/training/create
Authorizations
bearerAuth: Authorization header (String, required). Authenticate with the Open API Key you have obtained, for example: Bearer [YOUR_API_KEY].
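The sketch below builds these headers without hard-coding the key, assuming you keep it in an environment variable. The variable name ALAYANEW_API_KEY and the helper auth_headers are illustrative choices, not part of the API.

```python
import os
from typing import Dict, Optional


def auth_headers(api_key: Optional[str] = None) -> Dict[str, str]:
    """Build JSON request headers with a Bearer token.

    Falls back to the ALAYANEW_API_KEY environment variable, a naming
    convention assumed for this sketch.
    """
    key = api_key or os.environ.get("ALAYANEW_API_KEY")
    if not key:
        raise ValueError("no API key: pass api_key or set ALAYANEW_API_KEY")
    return {
        "accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {key}",
    }
```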
Request body
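application/json
name (String, required): Training task name, for example: llama3-8b-sft.
createType (String, required): How the task is created. DIRECT creates it directly, COPY copies an existing task, and TEMPLATE creates it from a template. Use DIRECT for ordinary creation.
desc (String): Training task description, for example: Llama3-8B instruction fine-tuning.
priority (Integer): Task priority, from 1 to 3. A smaller value means a higher priority. Defaults to 3. When resources are scarce, higher-priority tasks are scheduled first.
trainingType (String): Training type: pre_training (pre-training) or HPC (high-performance computing).
imageType (String): Image type: general (base image), application (application image), or private (private image). Must match the value of image.
trainingFramwork (String): Training framework: pytorch, deepspeed, mpi, or tensorflow.
productCode (String, required): Product code of the resource configuration. It determines the CPU, memory, and GPU specification of each node. Obtain it from the resource configuration list, for example: PRD-TRAIN-1.
workerCount (Integer, required): Number of nodes, that is, the number of machines in a multi-node, multi-GPU job. Use 1 for a single machine and 4 for four machines.
image (String, required): Container image address, for example: harbor.zetyun.cn/anc-public/general/pytorch:2.3.1-gpu.
storageConfigs (Array): Storage mount configuration. Several NAS directories can be mounted into the container.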
Properties of each storageConfigs item:
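storageId (String): Storage ID, for example: 72a2a885-e45e-4c79-aaf3-e1fa05abdb92.
storageType (String): Storage type, for example: nas-capacity.
fileDirectory (String): Source directory within the storage, for example: datasets.
mountPath (String): Mount path inside the container, for example: /root/nas/.
onlyRead (Boolean): Access mode: true for read-only, false for read-write. Defaults to true.
env (Object): Environment variables, as key-value pairs.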
Properties of env:
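Additional properties (String): Additional variables, for example: "CUDA_VISIBLE_DEVICES":"0","CUDA_VISIBLE_MASTER":"yu".
enableAutoRetry (Boolean): Whether to retry automatically when the task fails, for example: true.
maxRetryCount (Integer): Maximum number of retries. Takes effect only when enableAutoRetry is true. For example: 3.
enableTimeoutCancel (Boolean): Whether to cancel the task automatically when it times out, for example: true.
timeoutHours (Integer): Timeout, in hours. Takes effect only when enableTimeoutCancel is true. For example: 24.
startCommand (String): Container start command, i.e. the training entry point. For example: python3 -m torch.distributed.run train.py --data /root/nas/datasets.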
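Several constraints above can be checked on the client before sending:

- required fields
- enumerated values
- the 1 to 3 priority range
- the flags that gate maxRetryCount and timeoutHours

The sketch below is a hypothetical helper, not an official validator, and the server remains the authority. It matches enum values exactly as spelled in this reference. Whether the server also accepts other casings is not stated here, so treat case handling as an assumption.

```python
from typing import Any, Dict, List

REQUIRED = ("name", "createType", "productCode", "workerCount", "image")
ENUMS = {
    "createType": {"DIRECT", "COPY", "TEMPLATE"},
    "trainingType": {"pre_training", "HPC"},
    "imageType": {"general", "application", "private"},
    "trainingFramwork": {"pytorch", "deepspeed", "mpi", "tensorflow"},
}


def validate_payload(payload: Dict[str, Any]) -> List[str]:
    """Return a list of problems found; an empty list means every check passed."""
    problems = [f"missing required field: {f}" for f in REQUIRED if f not in payload]
    for field, allowed in ENUMS.items():
        if field in payload and payload[field] not in allowed:
            problems.append(f"{field}={payload[field]!r} not in {sorted(allowed)}")
    priority = payload.get("priority", 3)  # documented default is 3
    if not (isinstance(priority, int) and 1 <= priority <= 3):
        problems.append("priority must be an integer in 1..3")
    workers = payload.get("workerCount")
    if workers is not None and (not isinstance(workers, int) or workers < 1):
        problems.append("workerCount must be a positive integer")
    # These two settings are documented as ineffective without their flag.
    if "maxRetryCount" in payload and not payload.get("enableAutoRetry"):
        problems.append("maxRetryCount has no effect unless enableAutoRetry is true")
    if "timeoutHours" in payload and not payload.get("enableTimeoutCancel"):
        problems.append("timeoutHours has no effect unless enableTimeoutCancel is true")
    for key, value in payload.get("env", {}).items():
        if not isinstance(value, str):
            problems.append(f"env[{key!r}] should be a string")
    return problems
```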
Response
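application/json · 200
status (Integer): Business status code. 200 means the request was accepted.
message (String): Response message, for example: "OK".
data (String): The subscription result identifier, resultKey. Use it with the Query Training Task Subscription Result endpoint to poll for the final creation result. For example: "a1b2c3d4-0000-1111-2222-333344445555".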
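The sketch below pulls the resultKey out of a response body, assuming the envelope shown here (status, message, data). The error examples at the end of this page use the same envelope with an empty data object, so the helper raises on any status other than 200. ApiError and extract_result_key are illustrative names.

```python
from typing import Any, Dict


class ApiError(Exception):
    """Raised when the response envelope does not carry a usable resultKey."""

    def __init__(self, status: Any, message: Any):
        super().__init__(f"status={status} message={message}")
        self.status = status
        self.message = message


def extract_result_key(body: Dict[str, Any]) -> str:
    """Return the resultKey carried in `data` of a create-task response."""
    if body.get("status") != 200:
        raise ApiError(body.get("status"), body.get("message"))
    key = body.get("data")
    if not isinstance(key, str) or not key:
        raise ApiError(body.get("status"), "response has no resultKey in data")
    return key
```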
cURL
curl -X 'POST' \
'https://api.alayanew.com/v1/subscription/training/create' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer [YOUR_API_KEY]' \
-d '{
"name": "my training task",
"createType": "DIRECT",
"desc": "my training task",
"priority": 3,
"trainingType": "pre_training",
"imageType": "general",
"trainingFramwork": "pytorch",
"productCode": "PRD-QTT",
"workerCount": 1,
"image": "harbor.unite.zetyun.cn/alayanew-87cbf7ca-f126-4194-b137-8fde5adcd09b/fq-torch-1120:latest",
"storageConfigs": [
{
"storageId": "72a2a885-e45e-4c79-aaf3-e1fa05abdb92",
"storageType": "nas-capacity",
"fileDirectory": "nas123",
"mountPath": "/root/nas/",
"onlyRead": true
}
],
"env": {
"CUDA_VISIBLE_DEVICES": "0",
"CUDA_VISIBLE_MASTER": "yu"
},
"enableAutoRetry": true,
"maxRetryCount": 3,
"enableTimeoutCancel": true,
"timeoutHours": 1,
"startCommand": "cat > /tmp/mnist.py << 'EOF'\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nimport os\n\ndef setup_device():\n if torch.cuda.is_available():\n local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n device = torch.device(f\"cuda:{local_rank}\")\n torch.cuda.set_device(device)\n backend = \"nccl\"\n else:\n device = torch.device(\"cpu\")\n backend = \"gloo\"\n return device, backend\n\ndef main():\n device, backend = setup_device()\n dist.init_process_group(backend=backend)\n rank = dist.get_rank()\n world_size = dist.get_world_size()\n\n model = nn.Linear(10, 10).to(device)\n\n if backend == \"nccl\":\n model = DDP(model, device_ids=[device.index])\n else:\n model = DDP(model)\n\n inputs = torch.randn(20, 10).to(device)\n labels = torch.randn(20, 10).to(device)\n\n optimizer = optim.SGD(model.parameters(), lr=0.01)\n loss_fn = nn.MSELoss()\n\n for epoch in range(10):\n optimizer.zero_grad()\n outputs = model(inputs)\n loss = loss_fn(outputs, labels)\n loss.backward()\n optimizer.step()\n\n if rank == 0:\n print(f\"Epoch {epoch}, Loss: {loss.item():.4f}, World Size: {world_size}\")\n\n dist.destroy_process_group()\n\nif __name__ == \"__main__\":\n main()\nEOF\npython3 /tmp/mnist.py"
}'

Python
import requests
url = "https://api.alayanew.com/v1/subscription/training/create"
headers = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": "Bearer [YOUR_API_KEY]"
}
payload = {
"name": "llama3-8b-sft",
"createType": "DIRECT",
"desc": "Llama3-8B instruction fine-tuning",
"priority": 3,
"trainingType": "pre_training",
"imageType": "general",
"trainingFramwork": "pytorch",
"productCode": "PRD-TRAIN-1",
"workerCount": 4,
"image": "harbor.zetyun.cn/anc-public/general/pytorch:2.3.1-gpu",
"storageConfigs": [
{
"storageId": "72a2a885-e45e-4c79-aaf3-e1fa05abdb92",
"storageType": "nas-capacity",
"fileDirectory": "datasets",
"mountPath": "/root/nas/",
"onlyRead": False
}
],
"env": {"CUDA_VISIBLE_DEVICES": "0,1,2,3"},
"enableAutoRetry": True,
"maxRetryCount": 3,
"enableTimeoutCancel": True,
"timeoutHours": 24,
"startCommand": "python3 -m torch.distributed.run train.py --data /root/nas/datasets"
}
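response = requests.post(url, headers=headers, json=payload, timeout=30)
response.raise_for_status()
# On success, "data" holds the resultKey used to poll for the creation result
print(response.json())

JavaScript
const payload = {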
name: 'llama3-8b-sft',
createType: 'DIRECT',
desc: 'Llama3-8B instruction fine-tuning',
priority: 3,
trainingType: 'pre_training',
imageType: 'general',
trainingFramwork: 'pytorch',
productCode: 'PRD-TRAIN-1',
workerCount: 4,
image: 'harbor.zetyun.cn/anc-public/general/pytorch:2.3.1-gpu',
storageConfigs: [
{
storageId: '72a2a885-e45e-4c79-aaf3-e1fa05abdb92',
storageType: 'nas-capacity',
fileDirectory: 'datasets',
mountPath: '/root/nas/',
onlyRead: false
}
],
env: { CUDA_VISIBLE_DEVICES: '0,1,2,3' },
enableAutoRetry: true,
maxRetryCount: 3,
enableTimeoutCancel: true,
timeoutHours: 24,
startCommand: 'python3 -m torch.distributed.run train.py --data /root/nas/datasets'
};
fetch('https://api.alayanew.com/v1/subscription/training/create', {
method: 'POST',
headers: {
'accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': 'Bearer [YOUR_API_KEY]'
},
body: JSON.stringify(payload)
})
.then(res => {
if (!res.ok) {
throw new Error(`HTTP error! status: ${res.status}`);
}
return res.json();
})
.then(console.log)
.catch(console.error);

Response example (200)
{
"status": 200,
"message": "OK",
"data": "a1b2c3d4-0000-1111-2222-333344445555"
}

Response example (403)
{
"status": 403,
"message": "Forbidden",
"data": {}
}

Response example (500)
{
"status": 500,
"message": "Internal Server Error",
"data": {}
}
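The cURL example embeds an entire Python script in startCommand. It writes the script to a file with a shell heredoc and then runs it. The sketch below generates such a command from script text, assuming a POSIX-compatible shell runs startCommand in the container; the page does not name the shell.

- Quoting the 'EOF' delimiter stops the shell from expanding $variables or backticks inside the script.
- json.dumps escapes the newlines for the request body.

heredoc_start_command and start_command_field are hypothetical helpers.

```python
import json
import shlex


def heredoc_start_command(script: str, path: str = "/tmp/train.py",
                          python: str = "python3") -> str:
    """Build a shell command that writes `script` to `path` and then runs it."""
    # A line consisting solely of EOF would terminate the heredoc early.
    if "\nEOF\n" in f"\n{script}\n":
        raise ValueError("script contains a line that is exactly 'EOF'")
    quoted = shlex.quote(path)
    # The quoted delimiter ('EOF') disables $var and `cmd` expansion in the body.
    return f"cat > {quoted} << 'EOF'\n{script.rstrip()}\nEOF\n{python} {quoted}"


def start_command_field(script: str, path: str = "/tmp/train.py") -> str:
    """JSON-encode the command as a startCommand field; newlines are escaped."""
    return json.dumps({"startCommand": heredoc_start_command(script, path)})
```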
