diff --git a/bin/cml/runner.js b/bin/cml/runner.js index 3914de905..91df42088 100755 --- a/bin/cml/runner.js +++ b/bin/cml/runner.js @@ -118,6 +118,7 @@ const runCloud = async (opts) => { cloud, cloudRegion: region, cloudType: type, + cloudMetadata: metadata, cloudGpu: gpu, cloudHddSize: hddSize, cloudSshPrivate: sshPrivate, @@ -151,6 +152,7 @@ const runCloud = async (opts) => { cloud, region, type, + metadata, gpu: gpu === 'tesla' ? 'v100' : gpu, hddSize, sshPrivate, @@ -190,6 +192,7 @@ const runCloud = async (opts) => { instanceLaunchTime: attributes.instance_launch_time, instanceType: attributes.instance_type, labels: attributes.labels, + metadata: attributes.metadata, name: attributes.name, region: attributes.region, repo: attributes.repo, @@ -467,6 +470,19 @@ exports.builder = (yargs) => description: 'Instance type. Choices: [m, l, xl]. Also supports native types like i.e. t2.micro' }, + cloudMetadata: { + type: 'array', + string: true, + coerce: (items) => { + const keyValuePairs = items.map((item) => [ + ...item.split(/=(.+)/), + null + ]); + return Object.fromEntries(keyValuePairs); + }, + description: + 'Key Value pairs to associate cml-runner instance on the provider i.e. tags/labels "key=value"' + }, cloudGpu: { type: 'string', choices: ['nogpu', 'k80', 'v100', 'tesla'], diff --git a/bin/cml/runner.test.js b/bin/cml/runner.test.js index b34e58f68..dee59c673 100644 --- a/bin/cml/runner.test.js +++ b/bin/cml/runner.test.js @@ -42,6 +42,9 @@ Options: [string] [default: \\"us-west\\"] --cloud-type Instance type. Choices: [m, l, xl]. Also supports native types like i.e. t2.micro [string] + --cloud-metadata Key Value pairs to associate cml-runner instance + on the provider i.e. tags/labels \\"key=value\\" + [array] --cloud-gpu GPU type. [string] [choices: \\"nogpu\\", \\"k80\\", \\"v100\\", \\"tesla\\"] --cloud-hdd-size HDD size in GB [number] diff --git a/src/terraform.js b/src/terraform.js index 38ad58a7d..91fdb160c 100644 --- a/src/terraform.js +++ b/src/terraform.js @@ -48,6 +48,9 @@ const destroy = async (opts = {}) => { ); }; +const mapCloudMetadata = (metadata) => + Object.entries(metadata).map(([key, value]) => `${key} = "${value || ''}"`); + const iterativeProviderTpl = () => { return ` terraform { @@ -74,6 +77,7 @@ const iterativeCmlRunnerTpl = (opts = {}) => { name, single, type, + metadata, gpu, hddSize, sshPrivate, @@ -83,7 +87,7 @@ const iterativeCmlRunnerTpl = (opts = {}) => { awsSecurityGroup } = opts; - return ` + const template = ` ${iterativeProviderTpl()} resource "iterative_cml_runner" "runner" { @@ -108,8 +112,14 @@ resource "iterative_cml_runner" "runner" { ${spotPrice ? `spot_price = ${spotPrice}` : ''} ${startupScript ? `startup_script = "${startupScript}"` : ''} ${awsSecurityGroup ? `aws_security_group = "${awsSecurityGroup}"` : ''} + ${ + metadata + ? `metadata = {\n ${mapCloudMetadata(metadata).join('\n ')}\n }` + : '' + } } `; + return template; }; const checkMinVersion = async () => { diff --git a/src/terraform.test.js b/src/terraform.test.js index 9a6f5c935..3fb7a4754 100644 --- a/src/terraform.test.js +++ b/src/terraform.test.js @@ -35,6 +35,7 @@ describe('Terraform tests', () => { + } " `); @@ -60,40 +61,41 @@ describe('Terraform tests', () => { awsSecurityGroup: 'mysg' }); expect(output).toMatchInlineSnapshot(` - " +" - terraform { - required_providers { - iterative = { - source = \\"iterative/iterative\\" - } - } - } +terraform { + required_providers { + iterative = { + source = \\"iterative/iterative\\" + } + } +} - provider \\"iterative\\" {} +provider \\"iterative\\" {} - resource \\"iterative_cml_runner\\" \\"runner\\" { - repo = \\"https://\\" - token = \\"abc\\" - driver = \\"gitlab\\" - labels = \\"mylabel\\" - idle_timeout = 300 - name = \\"myrunner\\" - single = \\"true\\" - cloud = \\"aws\\" - region = \\"west\\" - instance_type = \\"mymachinetype\\" - instance_gpu = \\"mygputype\\" - instance_hdd_size = 50 - ssh_private = \\"myprivate\\" - spot = true - spot_price = 0.0001 - - aws_security_group = \\"mysg\\" - } - " - `); +resource \\"iterative_cml_runner\\" \\"runner\\" { + repo = \\"https://\\" + token = \\"abc\\" + driver = \\"gitlab\\" + labels = \\"mylabel\\" + idle_timeout = 300 + name = \\"myrunner\\" + single = \\"true\\" + cloud = \\"aws\\" + region = \\"west\\" + instance_type = \\"mymachinetype\\" + instance_gpu = \\"mygputype\\" + instance_hdd_size = 50 + ssh_private = \\"myprivate\\" + spot = true + spot_price = 0.0001 + + aws_security_group = \\"mysg\\" + +} +" +`); }); test('basic settings with runner forever', async () => { @@ -115,40 +117,101 @@ describe('Terraform tests', () => { spotPrice: '0.0001' }); expect(output).toMatchInlineSnapshot(` - " +" - terraform { - required_providers { - iterative = { - source = \\"iterative/iterative\\" - } - } - } +terraform { + required_providers { + iterative = { + source = \\"iterative/iterative\\" + } + } +} - provider \\"iterative\\" {} +provider \\"iterative\\" {} - resource \\"iterative_cml_runner\\" \\"runner\\" { - repo = \\"https://\\" - token = \\"abc\\" - driver = \\"gitlab\\" - labels = \\"mylabel\\" - idle_timeout = 0 - name = \\"myrunner\\" - single = \\"true\\" - cloud = \\"aws\\" - region = \\"west\\" - instance_type = \\"mymachinetype\\" - instance_gpu = \\"mygputype\\" - instance_hdd_size = 50 - ssh_private = \\"myprivate\\" - spot = true - spot_price = 0.0001 - - - } - " - `); +resource \\"iterative_cml_runner\\" \\"runner\\" { + repo = \\"https://\\" + token = \\"abc\\" + driver = \\"gitlab\\" + labels = \\"mylabel\\" + idle_timeout = 0 + name = \\"myrunner\\" + single = \\"true\\" + cloud = \\"aws\\" + region = \\"west\\" + instance_type = \\"mymachinetype\\" + instance_gpu = \\"mygputype\\" + instance_hdd_size = 50 + ssh_private = \\"myprivate\\" + spot = true + spot_price = 0.0001 + + + +} +" +`); + }); + + test('basic settings with metadata', async () => { + const output = iterativeCmlRunnerTpl({ + repo: 'https://', + token: 'abc', + driver: 'gitlab', + labels: 'mylabel', + idleTimeout: 300, + name: 'myrunner', + single: true, + cloud: 'aws', + region: 'west', + type: 'mymachinetype', + gpu: 'mygputype', + hddSize: 50, + sshPrivate: 'myprivate', + spot: true, + spotPrice: '0.0001', + metadata: { one: 'value', two: null } + }); + expect(output).toMatchInlineSnapshot(` +" + +terraform { + required_providers { + iterative = { + source = \\"iterative/iterative\\" + } + } +} + +provider \\"iterative\\" {} + + +resource \\"iterative_cml_runner\\" \\"runner\\" { + repo = \\"https://\\" + token = \\"abc\\" + driver = \\"gitlab\\" + labels = \\"mylabel\\" + idle_timeout = 300 + name = \\"myrunner\\" + single = \\"true\\" + cloud = \\"aws\\" + region = \\"west\\" + instance_type = \\"mymachinetype\\" + instance_gpu = \\"mygputype\\" + instance_hdd_size = 50 + ssh_private = \\"myprivate\\" + spot = true + spot_price = 0.0001 + + + metadata = { + one = \\"value\\" + two = \\"\\" + } +} +" +`); }); test('Startup script', async () => { @@ -171,39 +234,40 @@ describe('Terraform tests', () => { startupScript: 'c3VkbyBlY2hvICdoZWxsbyB3b3JsZCcgPj4gL3Vzci9oZWxsby50eHQ=' }); expect(output).toMatchInlineSnapshot(` - " +" - terraform { - required_providers { - iterative = { - source = \\"iterative/iterative\\" - } - } - } +terraform { + required_providers { + iterative = { + source = \\"iterative/iterative\\" + } + } +} - provider \\"iterative\\" {} +provider \\"iterative\\" {} - resource \\"iterative_cml_runner\\" \\"runner\\" { - repo = \\"https://\\" - token = \\"abc\\" - driver = \\"gitlab\\" - labels = \\"mylabel\\" - idle_timeout = 300 - name = \\"myrunner\\" - single = \\"true\\" - cloud = \\"aws\\" - region = \\"west\\" - instance_type = \\"mymachinetype\\" - instance_gpu = \\"mygputype\\" - instance_hdd_size = 50 - ssh_private = \\"myprivate\\" - spot = true - spot_price = 0.0001 - startup_script = \\"c3VkbyBlY2hvICdoZWxsbyB3b3JsZCcgPj4gL3Vzci9oZWxsby50eHQ=\\" - - } - " - `); +resource \\"iterative_cml_runner\\" \\"runner\\" { + repo = \\"https://\\" + token = \\"abc\\" + driver = \\"gitlab\\" + labels = \\"mylabel\\" + idle_timeout = 300 + name = \\"myrunner\\" + single = \\"true\\" + cloud = \\"aws\\" + region = \\"west\\" + instance_type = \\"mymachinetype\\" + instance_gpu = \\"mygputype\\" + instance_hdd_size = 50 + ssh_private = \\"myprivate\\" + spot = true + spot_price = 0.0001 + startup_script = \\"c3VkbyBlY2hvICdoZWxsbyB3b3JsZCcgPj4gL3Vzci9oZWxsby50eHQ=\\" + + +} +" +`); }); });