forked from tinkerbell/ipxedust
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcmd_test.go
143 lines (136 loc) · 4.11 KB
/
cmd_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package ipxedust
import (
"context"
"errors"
"flag"
"fmt"
"testing"
"time"
"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/phayes/freeport"
)
// TestCommand_RegisterFlags checks that RegisterFlags populates a fresh
// "ipxe" FlagSet identically to a hand-built reference FlagSet that declares
// the expected flag names, defaults, and usage strings. The Usage func field
// is excluded from the comparison.
func TestCommand_RegisterFlags(t *testing.T) {
	// reference builds the FlagSet that RegisterFlags is expected to produce.
	reference := func() *flag.FlagSet {
		cmd := &Command{}
		set := flag.NewFlagSet("ipxe", flag.ExitOnError)
		set.StringVar(&cmd.TFTPAddr, "tftp-addr", "0.0.0.0:69", "TFTP server address")
		set.DurationVar(&cmd.TFTPTimeout, "tftp-timeout", time.Second*5, "TFTP server timeout")
		set.StringVar(&cmd.HTTPAddr, "http-addr", "0.0.0.0:8080", "HTTP server address")
		set.DurationVar(&cmd.HTTPTimeout, "http-timeout", time.Second*5, "HTTP server timeout")
		set.StringVar(&cmd.LogLevel, "log-level", "info", "Log level")
		set.BoolVar(&cmd.EnableTFTPSinglePort, "tftp-single-port", false, "Enable single port mode for TFTP server (needed for container deploys)")
		return set
	}

	type testCase struct {
		name string
		want *flag.FlagSet
	}
	cases := []testCase{
		{name: "success", want: reference()},
	}

	// Compare unexported FlagSet internals, but skip Usage: func values are
	// not comparable.
	opts := []cmp.Option{
		cmp.AllowUnexported(flag.FlagSet{}),
		cmpopts.IgnoreFields(flag.FlagSet{}, "Usage"),
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := flag.NewFlagSet("ipxe", flag.ExitOnError)
			(&Command{}).RegisterFlags(got)
			diff := cmp.Diff(got, tc.want, opts...)
			if diff != "" {
				t.Fatal(diff)
			}
		})
	}
}
// getPort returns a port number that freeport reported as free at the time
// of the call.
//
// It panics if no free port can be obtained. Ignoring the error would hand
// callers a meaningless port number. They would then build listen addresses
// that fail later with an unrelated bind or validation error, hiding the
// real cause.
//
// NOTE: the port is only known to be free when this function runs. Another
// process may claim it before the test binds it, so tests that rely on it
// can be flaky on busy hosts.
func getPort() int {
	port, err := freeport.GetFreePort()
	if err != nil {
		panic(fmt.Sprintf("getPort: unable to obtain a free port: %v", err))
	}
	return port
}
// TestCommand_Run calls Command.Run under a one-second deadline and compares
// the string form of the returned error with the expected one. An error
// matching context.DeadlineExceeded is treated as nil (success): Run kept
// going until its context expired instead of failing earlier.
func TestCommand_Run(t *testing.T) {
	tests := []struct {
		name    string
		cmd     *Command
		wantErr error
	}{
		{"success", &Command{TFTPAddr: fmt.Sprintf("0.0.0.0:%v", getPort()), HTTPAddr: fmt.Sprintf("0.0.0.0:%v", getPort())}, nil},
		// NOTE(review): assumes the test process cannot bind port 80 (e.g. it
		// is not root and unprivileged ports start above 80); confirm for CI
		// and container environments.
		{"fail permission denied", &Command{TFTPAddr: "127.0.0.1:80"}, fmt.Errorf("listen udp 127.0.0.1:80: bind: permission denied")},
		// The two parse-error cases need distinct names; duplicate subtest
		// names are silently suffixed with "#01" by the testing package.
		{"fail tftp parse error", &Command{TFTPAddr: "127.0.0.1:AF"}, fmt.Errorf(`invalid port "AF" parsing "127.0.0.1:AF"`)},
		{"fail http parse error", &Command{HTTPAddr: "127.0.0.1:AF"}, fmt.Errorf(`invalid port "AF" parsing "127.0.0.1:AF"`)},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()

			// Buffered so the goroutine can always deliver its result and exit,
			// even if the subtest has already failed on the hang guard below.
			errChan := make(chan error, 1)
			go func() {
				errChan <- tt.cmd.Run(ctx)
			}()

			// Wait for Run itself rather than for the deadline, so cases that
			// error early finish immediately. The guard turns a Run that ignores
			// its context into a test failure instead of a hung test binary.
			var got error
			select {
			case got = <-errChan:
			case <-time.After(30 * time.Second):
				t.Fatal("Run did not return after its context deadline")
			}

			if errors.Is(got, context.DeadlineExceeded) {
				got = nil
			}
			if diff := cmp.Diff(fmt.Sprint(tt.wantErr), fmt.Sprint(got)); diff != "" {
				t.Fatalf("error mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
// TestCommand_Validate checks the error returned by Command.Validate.
// A fully populated Command must validate cleanly. The same Command with
// TFTPAddr left empty must fail on the 'required' tag. Errors are compared
// by their string form.
func TestCommand_Validate(t *testing.T) {
	type testCase struct {
		name    string
		cmd     *Command
		wantErr error
	}
	cases := []testCase{
		{
			name: "success",
			cmd: &Command{
				TFTPAddr:    "0.0.0.0:69",
				TFTPTimeout: 5 * time.Second,
				HTTPAddr:    "0.0.0.0:8080",
				HTTPTimeout: 5 * time.Second,
				Log:         logr.Discard(),
				LogLevel:    "info",
			},
			wantErr: nil,
		},
		{
			name: "fail",
			cmd: &Command{
				TFTPTimeout: 5 * time.Second,
				HTTPAddr:    "0.0.0.0:8080",
				HTTPTimeout: 5 * time.Second,
				Log:         logr.Discard(),
				LogLevel:    "info",
			},
			wantErr: fmt.Errorf(`Key: 'Command.TFTPAddr' Error:Field validation for 'TFTPAddr' failed on the 'required' tag`),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotText := fmt.Sprint(tc.cmd.Validate())
			wantText := fmt.Sprint(tc.wantErr)
			diff := cmp.Diff(gotText, wantText)
			if diff != "" {
				t.Fatal(diff)
			}
		})
	}
}
// TestExecute calls Execute with command-line style arguments under a
// one-second deadline and compares the string form of the returned error
// with the expected one. An error matching context.DeadlineExceeded is
// treated as nil (success): Execute kept going until its context expired
// instead of failing earlier.
func TestExecute(t *testing.T) {
	tests := []struct {
		name    string
		args    []string
		wantErr error
	}{
		{"success", []string{fmt.Sprintf("--tftp-addr=0.0.0.0:%v", getPort()), fmt.Sprintf("--http-addr=0.0.0.0:%v", getPort())}, nil},
		{"fail validation", []string{"--tftp-addr=0.0.0.0:AF", "--log-level=debug"}, fmt.Errorf(`Key: 'Command.TFTPAddr' Error:Field validation for 'TFTPAddr' failed on the 'hostname_port' tag`)},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()

			// Buffered so the goroutine can always deliver its result and exit,
			// even if the subtest has already failed on the hang guard below.
			errChan := make(chan error, 1)
			go func() {
				errChan <- Execute(ctx, tt.args)
			}()

			// Wait for Execute itself rather than for the deadline, so an early
			// failure (such as the validation case) is reported immediately. The
			// guard turns an Execute that ignores its context into a test failure
			// instead of a hung test binary.
			var got error
			select {
			case got = <-errChan:
			case <-time.After(30 * time.Second):
				t.Fatal("Execute did not return after its context deadline")
			}

			if errors.Is(got, context.DeadlineExceeded) {
				got = nil
			}
			if diff := cmp.Diff(fmt.Sprint(tt.wantErr), fmt.Sprint(got)); diff != "" {
				t.Fatalf("error mismatch (-want +got):\n%s", diff)
			}
		})
	}
}