Skip to content

Commit

Permalink
Minor improvements + Add test for HandleRollingUpgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
TwiN committed May 7, 2020
1 parent 6fcaad5 commit d34c964
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 107 deletions.
68 changes: 0 additions & 68 deletions cloudtest/aws.go

This file was deleted.

132 changes: 132 additions & 0 deletions cloudtest/cloudtest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package cloudtest

import (
"errors"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/autoscaling/autoscalingiface"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

type MockEC2Service struct {
ec2iface.EC2API

Counter map[string]int64
Templates []*ec2.LaunchTemplate
}

func NewMockEC2Service(templates []*ec2.LaunchTemplate) *MockEC2Service {
return &MockEC2Service{
Counter: make(map[string]int64),
Templates: templates,
}
}

func (m *MockEC2Service) DescribeLaunchTemplates(_ *ec2.DescribeLaunchTemplatesInput) (*ec2.DescribeLaunchTemplatesOutput, error) {
m.Counter["DescribeLaunchTemplates"]++
output := &ec2.DescribeLaunchTemplatesOutput{
LaunchTemplates: m.Templates,
}
return output, nil
}

func (m *MockEC2Service) DescribeLaunchTemplateByID(input *ec2.DescribeLaunchTemplatesInput) (*ec2.LaunchTemplate, error) {
m.Counter["DescribeLaunchTemplateByID"]++
for _, template := range m.Templates {
if template.LaunchTemplateId == input.LaunchTemplateIds[0] {
return template, nil
}
if template.LaunchTemplateName == input.LaunchTemplateNames[0] {
return template, nil
}
}
return nil, errors.New("not found")
}

func CreateTestEc2Instance(id string) *ec2.Instance {
instance := &ec2.Instance{
InstanceId: aws.String(id),
}
return instance
}

type MockAutoScalingService struct {
autoscalingiface.AutoScalingAPI

Counter map[string]int64
AutoScalingGroups map[string]*autoscaling.Group
}

func NewMockAutoScalingService(autoScalingGroups []*autoscaling.Group) *MockAutoScalingService {
service := &MockAutoScalingService{
Counter: make(map[string]int64),
AutoScalingGroups: make(map[string]*autoscaling.Group),
}
for _, autoScalingGroup := range autoScalingGroups {
service.AutoScalingGroups[aws.StringValue(autoScalingGroup.AutoScalingGroupName)] = autoScalingGroup
}
return service
}

func (m *MockAutoScalingService) TerminateInstanceInAutoScalingGroup(_ *autoscaling.TerminateInstanceInAutoScalingGroupInput) (*autoscaling.TerminateInstanceInAutoScalingGroupOutput, error) {
m.Counter["TerminateInstanceInAutoScalingGroup"]++
return &autoscaling.TerminateInstanceInAutoScalingGroupOutput{}, nil
}

func (m *MockAutoScalingService) DescribeAutoScalingGroups(input *autoscaling.DescribeAutoScalingGroupsInput) (*autoscaling.DescribeAutoScalingGroupsOutput, error) {
m.Counter["DescribeAutoScalingGroups"]++
var autoScalingGroups []*autoscaling.Group
for _, autoScalingGroupName := range input.AutoScalingGroupNames {
for _, autoScalingGroup := range m.AutoScalingGroups {
if aws.StringValue(autoScalingGroupName) == aws.StringValue(autoScalingGroup.AutoScalingGroupName) {
autoScalingGroups = append(autoScalingGroups, autoScalingGroup)
}
}
}
return &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: autoScalingGroups,
}, nil
}

func (m *MockAutoScalingService) SetDesiredCapacity(input *autoscaling.SetDesiredCapacityInput) (*autoscaling.SetDesiredCapacityOutput, error) {
m.Counter["SetDesiredCapacity"]++
m.AutoScalingGroups[aws.StringValue(input.AutoScalingGroupName)].SetDesiredCapacity(aws.Int64Value(input.DesiredCapacity))
return &autoscaling.SetDesiredCapacityOutput{}, nil
}

func (m *MockAutoScalingService) UpdateAutoScalingGroup(_ *autoscaling.UpdateAutoScalingGroupInput) (*autoscaling.UpdateAutoScalingGroupOutput, error) {
m.Counter["UpdateAutoScalingGroup"]++
return &autoscaling.UpdateAutoScalingGroupOutput{}, nil
}

func CreateTestAutoScalingGroup(name, launchConfigurationName string, launchTemplateSpecification *autoscaling.LaunchTemplateSpecification, instances []*autoscaling.Instance) *autoscaling.Group {
asg := &autoscaling.Group{
AutoScalingGroupName: aws.String(name),
Instances: instances,
DesiredCapacity: aws.Int64(int64(len(instances))),
MinSize: aws.Int64(0),
MaxSize: aws.Int64(999),
}
if len(launchConfigurationName) != 0 {
asg.SetLaunchConfigurationName(launchConfigurationName)
}
if launchTemplateSpecification != nil {
asg.SetLaunchTemplate(launchTemplateSpecification)
}
return asg
}

func CreateTestAutoScalingInstance(id, launchConfigurationName string, launchTemplateSpecification *autoscaling.LaunchTemplateSpecification, lifeCycleState string) *autoscaling.Instance {
instance := &autoscaling.Instance{
LifecycleState: aws.String(lifeCycleState),
InstanceId: aws.String(id),
}
if len(launchConfigurationName) != 0 {
instance.SetLaunchConfigurationName(launchConfigurationName)
}
if launchTemplateSpecification != nil {
instance.SetLaunchTemplate(launchTemplateSpecification)
}
return instance
}
29 changes: 21 additions & 8 deletions k8stest/k8stest.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,38 @@ import (

type MockKubernetesClient struct {
Counter map[string]int64
nodes []v1.Node
pods []v1.Pod
Nodes map[string]v1.Node
Pods map[string]v1.Pod
}

func NewMockKubernetesClient(nodes []v1.Node, pods []v1.Pod) *MockKubernetesClient {
return &MockKubernetesClient{
client := &MockKubernetesClient{
Counter: make(map[string]int64),
nodes: nodes,
pods: pods,
Nodes: make(map[string]v1.Node),
Pods: make(map[string]v1.Pod),
}
for _, node := range nodes {
client.Nodes[node.Name] = node
}
for _, pod := range pods {
client.Pods[pod.Name] = pod
}
return client
}

func (mock *MockKubernetesClient) GetNodes() ([]v1.Node, error) {
mock.Counter["GetNodes"]++
return mock.nodes, nil
var nodes []v1.Node
for _, node := range mock.Nodes {
nodes = append(nodes, node)
}
return nodes, nil
}

func (mock *MockKubernetesClient) GetPodsInNode(node string) ([]v1.Pod, error) {
mock.Counter["GetPodsInNode"]++
var pods []v1.Pod
for _, pod := range mock.pods {
for _, pod := range mock.Pods {
if pod.Spec.NodeName == node {
pods = append(pods, pod)
}
Expand All @@ -38,7 +49,7 @@ func (mock *MockKubernetesClient) GetPodsInNode(node string) ([]v1.Pod, error) {

func (mock *MockKubernetesClient) GetNodeByHostName(hostName string) (*v1.Node, error) {
mock.Counter["GetNodeByHostName"]++
for _, node := range mock.nodes {
for _, node := range mock.Nodes {
// For the sake of simplicity, we'll just assume that the host name is the same as the node name
if node.Name == hostName {
return &node, nil
Expand All @@ -49,6 +60,7 @@ func (mock *MockKubernetesClient) GetNodeByHostName(hostName string) (*v1.Node,

func (mock *MockKubernetesClient) UpdateNode(node *v1.Node) error {
mock.Counter["UpdateNode"]++
mock.Nodes[node.Name] = *node
return nil
}

Expand All @@ -68,6 +80,7 @@ func CreateTestNode(name string, allocatableCpu, allocatableMemory string) v1.No
},
}
node.SetName(name)
node.SetAnnotations(make(map[string]string))
return node
}

Expand Down
Loading

0 comments on commit d34c964

Please sign in to comment.