From f697da86526097768048c3367b8b7802c8b0eaf1 Mon Sep 17 00:00:00 2001 From: kevin Date: Tue, 21 Sep 2021 12:28:06 +0800 Subject: [PATCH] add train & ask cli --- bot/adapters/logic/closestmatch.go | 10 ++--- bot/adapters/storage/memorystorage.go | 9 +++-- cli/ask/ask.go | 58 +++++++++++++++++++++++++++ cli/train/train.go | 39 ++++++++++++++++++ go.mod | 10 +---- go.sum | 1 + 6 files changed, 109 insertions(+), 18 deletions(-) create mode 100644 cli/ask/ask.go create mode 100644 cli/train/train.go diff --git a/bot/adapters/logic/closestmatch.go b/bot/adapters/logic/closestmatch.go index 44287ba..3c21cf7 100644 --- a/bot/adapters/logic/closestmatch.go +++ b/bot/adapters/logic/closestmatch.go @@ -95,8 +95,8 @@ func (match *closestMatch) processExactMatch(responses map[string]int) []Answer } func (match *closestMatch) processSimilarMatch(text string) []Answer { - result, ok := mr.MapReduce(generator(match, text), mapper(match), reducer(match)) - if !ok { + result, err := mr.MapReduce(generator(match, text), mapper(match), reducer(match)) + if err != nil{ return nil } @@ -146,7 +146,7 @@ func (top *topOccurAnswers) put(answer string, occurrence int) { } func generator(match *closestMatch, text string) mr.GenerateFunc { - return func(source chan interface{}) { + return func(source chan<- interface{}) { keys := match.storage.Search(text) if match.verbose { printMatches(keys) @@ -163,7 +163,7 @@ func generator(match *closestMatch, text string) mr.GenerateFunc { } func mapper(match *closestMatch) mr.MapperFunc { - return func(data interface{}, writer mr.Writer, cancel func()) { + return func(data interface{}, writer mr.Writer, cancel func(error)) { tops := newTopScoreQuestions(match.tops) pair := data.(sourceAndTargets) for i := range pair.targets { @@ -179,7 +179,7 @@ func mapper(match *closestMatch) mr.MapperFunc { } func reducer(match *closestMatch) mr.ReducerFunc { - return func(input chan interface{}, writer mr.Writer, cancel func()) { + return func(input <-chan interface{}, writer mr.Writer, cancel func(error)) { tops := newTopScoreQuestions(match.tops) for each := range input { qs := each.(*topScoreQuestions) diff --git a/bot/adapters/storage/memorystorage.go b/bot/adapters/storage/memorystorage.go index 213023b..830c92a 100644 --- a/bot/adapters/storage/memorystorage.go +++ b/bot/adapters/storage/memorystorage.go @@ -187,13 +187,14 @@ func (storage *memoryStorage) buildIndex(keys []string) map[string][]int { close(channel) }() - result, ok := mr.MapReduce(func(source chan interface{}) { + result, err := mr.MapReduce(func(source chan<- interface{}) { chunks := splitStrings(keys, chunkSize) for i := range chunks { source <- chunks[i] } }, storage.mapper, storage.reducer) - if !ok { + if err != nil { + fmt.Printf("error: %v\n", err) return nil } @@ -299,7 +300,7 @@ func (storage *memoryStorage) generateSearchResults(ids []int) []string { return result } -func (storage *memoryStorage) mapper(data interface{}, writer mr.Writer, cancel func()) { +func (storage *memoryStorage) mapper(data interface{}, writer mr.Writer, cancel func(error)) { indexes := make(map[string][]int) chunk := data.(*keyChunk) @@ -331,7 +332,7 @@ func (storage *memoryStorage) mapper(data interface{}, writer mr.Writer, cancel writer.Write(indexes) } -func (storage *memoryStorage) reducer(input chan interface{}, writer mr.Writer, cancel func()) { +func (storage *memoryStorage) reducer(input <-chan interface{}, writer mr.Writer, cancel func(error)) { indexes := make(map[string]map[int]struct{}) for each := range input { chunkIndexes := each.(map[string][]int) diff --git a/cli/ask/ask.go b/cli/ask/ask.go new file mode 100644 index 0000000..d818a25 --- /dev/null +++ b/cli/ask/ask.go @@ -0,0 +1,58 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "log" + "os" + "time" + + "github.com/kevwan/chatbot/bot" + "github.com/kevwan/chatbot/bot/adapters/logic" + "github.com/kevwan/chatbot/bot/adapters/storage" +) + +const tops = 5 + +var ( + verbose = flag.Bool("v", false, "verbose mode") + storeFile = flag.String("c", "corpus.gob", "the file to store corpora") +) + +func main() { + flag.Parse() + + store, err := storage.NewSeparatedMemoryStorage(*storeFile) + if err != nil { + log.Fatal(err) + } + + chatbot := &bot.ChatBot{ + LogicAdapter: logic.NewClosestMatch(store, tops), + } + if *verbose { + chatbot.LogicAdapter.SetVerbose() + } + + scanner := bufio.NewScanner(os.Stdin) + for { + fmt.Print("Q: ") + scanner.Scan() + question := scanner.Text() + if question == "exit" { + break + } + + startTime := time.Now() + answers := chatbot.GetResponse(question) + for i, answer := range answers { + fmt.Printf("%d: %s\n", i+1, answer.Content) + if *verbose { + fmt.Printf("%d: %s\tConfidence: %.3f\t%s\n", i+1, answer.Content, + answer.Confidence, time.Since(startTime)) + } + } + fmt.Println(time.Since(startTime)) + } +} diff --git a/cli/train/train.go b/cli/train/train.go new file mode 100644 index 0000000..53624e9 --- /dev/null +++ b/cli/train/train.go @@ -0,0 +1,39 @@ +package main + +import ( + "flag" + "log" + "strings" + + "github.com/kevwan/chatbot/bot" + "github.com/kevwan/chatbot/bot/adapters/storage" +) + +var ( + corpora = flag.String("i", "", "the corpora files, comma to separate multiple files") + storeFile = flag.String("o", "corpus.gob", "the file to store corpora") + printMemStats = flag.Bool("m", false, "enable printing memory stats") +) + +func main() { + flag.Parse() + + if len(*corpora) == 0 { + flag.Usage() + return + } + + store, err := storage.NewSeparatedMemoryStorage(*storeFile) + if err != nil { + log.Fatal(err) + } + + chatbot := &bot.ChatBot{ + PrintMemStats: *printMemStats, + Trainer: bot.NewCorpusTrainer(store), + StorageAdapter: store, + } + if err := chatbot.Train(strings.Split(*corpora, ",")); err != nil { + log.Fatal(err) + } +} diff --git a/go.mod b/go.mod index 73be63b..e2baf3b 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,8 @@ module github.com/kevwan/chatbot -go 1.17 +go 1.15 require ( github.com/tal-tech/go-zero v1.2.1 github.com/wangbin/jiebago v0.3.2 ) - -require ( - github.com/spaolacci/murmur3 v1.1.0 // indirect - go.opentelemetry.io/otel v1.0.0-RC2 // indirect - go.opentelemetry.io/otel/trace v1.0.0-RC2 // indirect - go.uber.org/automaxprocs v1.3.0 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect -) diff --git a/go.sum b/go.sum index 0427674..f702909 100644 --- a/go.sum +++ b/go.sum @@ -484,6 +484,7 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=