From 8a0c758c0224199ceb4b09be4f5fc929fd6fe21b Mon Sep 17 00:00:00 2001 From: kevin Date: Wed, 6 Oct 2021 22:34:05 +0800 Subject: [PATCH] feat: support loading whole directory on trainning --- cli/train/train.go | 47 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/cli/train/train.go b/cli/train/train.go index 53624e9..40d175d 100644 --- a/cli/train/train.go +++ b/cli/train/train.go @@ -2,7 +2,9 @@ package main import ( "flag" + "fmt" "log" + "path/filepath" "strings" "github.com/kevwan/chatbot/bot" @@ -11,6 +13,7 @@ import ( var ( corpora = flag.String("i", "", "the corpora files, comma to separate multiple files") + dir = flag.String("d", "", "the directory to look for corpora files") storeFile = flag.String("o", "corpus.gob", "the file to store corpora") printMemStats = flag.Bool("m", false, "enable printing memory stats") ) @@ -18,7 +21,19 @@ var ( func main() { flag.Parse() - if len(*corpora) == 0 { + var files []string + if len(*dir) > 0 { + files = findCorporaFiles(*dir) + } + + var corporaFiles string + if len(files) > 0 { + corporaFiles = strings.Join(files, ",") + } + if len(*corpora) > 0 { + corporaFiles = strings.Join([]string{corporaFiles, *corpora}, ",") + } + if len(corporaFiles) == 0 { flag.Usage() return } @@ -33,7 +48,35 @@ func main() { Trainer: bot.NewCorpusTrainer(store), StorageAdapter: store, } - if err := chatbot.Train(strings.Split(*corpora, ",")); err != nil { + if err := chatbot.Train(strings.Split(corporaFiles, ",")); err != nil { log.Fatal(err) } } + +func findCorporaFiles(dir string) []string { + var files []string + + jsonFiles, err := filepath.Glob(filepath.Join(dir, "*.json")) + if err != nil { + fmt.Println(err) + return nil + } + + files = append(files, jsonFiles...) + + ymlFiles, err := filepath.Glob(filepath.Join(dir, "*.yml")) + if err != nil { + fmt.Println(err) + return nil + } + + files = append(files, ymlFiles...) + + yamlFiles, err := filepath.Glob(filepath.Join(dir, "*.yaml")) + if err != nil { + fmt.Println(err) + return nil + } + + return append(files, yamlFiles...) +}