From d1fdafac4aa37496ae7619733d7ccfbf7f3bf13d Mon Sep 17 00:00:00 2001 From: andyzhangx Date: Tue, 24 Dec 2024 07:33:48 +0000 Subject: [PATCH] test: add unit test for main function --- cmd/smbplugin/main.go | 21 ++++++---- cmd/smbplugin/main_test.go | 82 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 8 deletions(-) create mode 100644 cmd/smbplugin/main_test.go diff --git a/cmd/smbplugin/main.go b/cmd/smbplugin/main.go index 50a75a9f48f..efa04e0756f 100644 --- a/cmd/smbplugin/main.go +++ b/cmd/smbplugin/main.go @@ -51,6 +51,11 @@ var ( removeArchivedVolumePath = flag.Bool("remove-archived-volume-path", true, "remove archived volume path in DeleteVolume") ) +// exit is a separate function to handle program termination +var exit = func(code int) { + os.Exit(code) +} + func main() { flag.Parse() if *ver { @@ -59,15 +64,15 @@ func main() { klog.Fatalln(err) } fmt.Println(info) // nolint - os.Exit(0) - } - if *nodeID == "" { - // nodeid is not needed in controller component - klog.Warning("nodeid is empty") + } else { + if *nodeID == "" { + // nodeid is not needed in controller component + klog.Warning("nodeid is empty") + } + exportMetrics() + handle() } - exportMetrics() - handle() - os.Exit(0) + exit(0) } func handle() { diff --git a/cmd/smbplugin/main_test.go b/cmd/smbplugin/main_test.go new file mode 100644 index 00000000000..ec5c6849977 --- /dev/null +++ b/cmd/smbplugin/main_test.go @@ -0,0 +1,82 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "net" + "os" + "reflect" + "testing" +) + +func TestMain(t *testing.T) { + // Set the version flag to true + os.Args = []string{"cmd", "-ver"} + + // Capture stdout + old := os.Stdout + _, w, _ := os.Pipe() + os.Stdout = w + + // Replace exit function with mock function + var exitCode int + exit = func(code int) { + exitCode = code + } + + // Call main function + main() + + // Restore stdout + w.Close() + os.Stdout = old + exit = func(code int) { + os.Exit(code) + } + + if exitCode != 0 { + t.Errorf("Expected exit code 0, but got %d", exitCode) + } +} + +func TestTrapClosedConnErr(t *testing.T) { + tests := []struct { + err error + expectedErr error + }{ + { + err: net.ErrClosed, + expectedErr: nil, + }, + { + err: nil, + expectedErr: nil, + }, + { + err: fmt.Errorf("some error"), + expectedErr: fmt.Errorf("some error"), + }, + } + + for _, test := range tests { + err := trapClosedConnErr(test.err) + if !reflect.DeepEqual(err, test.expectedErr) { + t.Errorf("Expected error %v, but got %v", test.expectedErr, err) + } + } +}