diff --git a/provision/ec2.go b/provision/ec2.go index 326e807..2e3fb2f 100644 --- a/provision/ec2.go +++ b/provision/ec2.go @@ -2,10 +2,11 @@ package provision import ( "fmt" - "github.com/aws/aws-sdk-go/aws/credentials" "strconv" "strings" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" @@ -40,10 +41,12 @@ func (p *EC2Provisioner) Provision(host BasicHost) (*ProvisionedHost, error) { } pro := host.Additional["pro"] + ports := host.Additional["ports"] + var vpcID = host.Additional["vpc-id"] var subnetID = host.Additional["subnet-id"] - groupID, name, err := p.createEC2SecurityGroup(vpcID, port, pro) + groupID, name, err := p.createEC2SecurityGroup(vpcID, port, pro, ports) if err != nil { return nil, err } @@ -85,6 +88,7 @@ func (p *EC2Provisioner) Provision(host BasicHost) (*ProvisionedHost, error) { return nil, fmt.Errorf("could not create host: %s", runResult.String()) } + // AE: not sure why this error isn't handled? _, err = p.ec2Provisioner.CreateTags(&ec2.CreateTagsInput{ Resources: []*string{runResult.Instances[0].InstanceId}, Tags: []*ec2.Tag{ @@ -247,9 +251,21 @@ func (p *EC2Provisioner) lookupID(request HostDeleteRequest) (string, error) { } // createEC2SecurityGroup creates a security group for the exit-node -func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, pro string) (*string, *string, error) { - ports := []int{80, 443, controlPort} - proPorts := []int{1024, 65535} +func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, pro, extraPorts string) (*string, *string, error) { + ports := []int{controlPort} + + proPortRange := []int{1024, 65535} + + if len(extraPorts) > 0 { + extraPorts, err := parsePorts(extraPorts) + if err != nil { + return nil, nil, err + } + ports = append(ports, extraPorts...) + + proPortRange = []int{} + } + groupName := "inlets-" + uuid.New().String() var input = &ec2.CreateSecurityGroupInput{ Description: aws.String("inlets security group"), @@ -271,8 +287,9 @@ func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, p return group.GroupId, &groupName, err } } - if pro == "true" { - err = p.createEC2SecurityGroupRule(*group.GroupId, proPorts[0], proPorts[1]) + + if pro == "true" && len(proPortRange) == 2 { + err = p.createEC2SecurityGroupRule(*group.GroupId, proPortRange[0], proPortRange[1]) if err != nil { return group.GroupId, &groupName, err } @@ -281,6 +298,22 @@ func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, p return group.GroupId, &groupName, nil } +func parsePorts(extraPorts string) ([]int, error) { + var ports []int + parts := strings.Split(extraPorts, ",") + for _, part := range parts { + if trimmed := strings.TrimSpace(part); len(trimmed) > 0 { + port, err := strconv.Atoi(trimmed) + if err != nil { + return nil, err + } + ports = append(ports, port) + } + } + + return ports, nil +} + func (p *EC2Provisioner) createEC2SecurityGroupRule(groupID string, fromPort, toPort int) error { _, err := p.ec2Provisioner.AuthorizeSecurityGroupIngress(&ec2.AuthorizeSecurityGroupIngressInput{ CidrIp: aws.String("0.0.0.0/0"), diff --git a/provision/ec2_test.go b/provision/ec2_test.go new file mode 100644 index 0000000..35b44ca --- /dev/null +++ b/provision/ec2_test.go @@ -0,0 +1,63 @@ +package provision + +import "testing" + +func Test_parsePorts_empty(t *testing.T) { + + ports, err := parsePorts("") + if err != nil { + t.Fatal(err) + } + + if len(ports) != 0 { + t.Fatalf("Expected empty slice, got %d", len(ports)) + } +} + +func Test_parsePorts_single(t *testing.T) { + + wantPort := 80 + str := "80" + ports, err := parsePorts(str) + if err != nil { + t.Fatal(err) + } + + if len(ports) != 1 { + t.Fatalf("Want single port, got %d", len(ports)) + } + + if ports[0] != wantPort { + t.Fatalf("Want port %d, got %d", wantPort, ports[0]) + } +} + +func Test_parsePorts_multiple(t *testing.T) { + + wantPorts := []int{27017, 22} + + str := "27017,22" + + ports, err := parsePorts(str) + if err != nil { + t.Fatal(err) + } + + if len(ports) != len(wantPorts) { + t.Fatalf("Want %d ports, got %d", len(wantPorts), len(ports)) + } + + found := 0 + + for _, port := range ports { + for _, wantPort := range wantPorts { + if port == wantPort { + found++ + } + } + } + + if found != len(wantPorts) { + t.Fatalf("Want %v ports, got %v", wantPorts, ports) + } +}