Skip to content

Commit

Permalink
add train & ask cli
Browse files Browse the repository at this point in the history
  • Loading branch information
kevwan committed Sep 21, 2021
1 parent d4dd2a8 commit f697da8
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 18 deletions.
10 changes: 5 additions & 5 deletions bot/adapters/logic/closestmatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ func (match *closestMatch) processExactMatch(responses map[string]int) []Answer
}

func (match *closestMatch) processSimilarMatch(text string) []Answer {
result, ok := mr.MapReduce(generator(match, text), mapper(match), reducer(match))
if !ok {
result, err := mr.MapReduce(generator(match, text), mapper(match), reducer(match))
if err != nil{
return nil
}

Expand Down Expand Up @@ -146,7 +146,7 @@ func (top *topOccurAnswers) put(answer string, occurrence int) {
}

func generator(match *closestMatch, text string) mr.GenerateFunc {
return func(source chan interface{}) {
return func(source chan<- interface{}) {
keys := match.storage.Search(text)
if match.verbose {
printMatches(keys)
Expand All @@ -163,7 +163,7 @@ func generator(match *closestMatch, text string) mr.GenerateFunc {
}

func mapper(match *closestMatch) mr.MapperFunc {
return func(data interface{}, writer mr.Writer, cancel func()) {
return func(data interface{}, writer mr.Writer, cancel func(error)) {
tops := newTopScoreQuestions(match.tops)
pair := data.(sourceAndTargets)
for i := range pair.targets {
Expand All @@ -179,7 +179,7 @@ func mapper(match *closestMatch) mr.MapperFunc {
}

func reducer(match *closestMatch) mr.ReducerFunc {
return func(input chan interface{}, writer mr.Writer, cancel func()) {
return func(input <-chan interface{}, writer mr.Writer, cancel func(error)) {
tops := newTopScoreQuestions(match.tops)
for each := range input {
qs := each.(*topScoreQuestions)
Expand Down
9 changes: 5 additions & 4 deletions bot/adapters/storage/memorystorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ func (storage *memoryStorage) buildIndex(keys []string) map[string][]int {
close(channel)
}()

result, ok := mr.MapReduce(func(source chan interface{}) {
result, err := mr.MapReduce(func(source chan<- interface{}) {
chunks := splitStrings(keys, chunkSize)
for i := range chunks {
source <- chunks[i]
}
}, storage.mapper, storage.reducer)
if !ok {
if err != nil {
fmt.Printf("error: %v\n", err)
return nil
}

Expand Down Expand Up @@ -299,7 +300,7 @@ func (storage *memoryStorage) generateSearchResults(ids []int) []string {
return result
}

func (storage *memoryStorage) mapper(data interface{}, writer mr.Writer, cancel func()) {
func (storage *memoryStorage) mapper(data interface{}, writer mr.Writer, cancel func(error)) {
indexes := make(map[string][]int)
chunk := data.(*keyChunk)

Expand Down Expand Up @@ -331,7 +332,7 @@ func (storage *memoryStorage) mapper(data interface{}, writer mr.Writer, cancel
writer.Write(indexes)
}

func (storage *memoryStorage) reducer(input chan interface{}, writer mr.Writer, cancel func()) {
func (storage *memoryStorage) reducer(input <-chan interface{}, writer mr.Writer, cancel func(error)) {
indexes := make(map[string]map[int]struct{})
for each := range input {
chunkIndexes := each.(map[string][]int)
Expand Down
58 changes: 58 additions & 0 deletions cli/ask/ask.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package main

import (
"bufio"
"flag"
"fmt"
"log"
"os"
"time"

"github.com/kevwan/chatbot/bot"
"github.com/kevwan/chatbot/bot/adapters/logic"
"github.com/kevwan/chatbot/bot/adapters/storage"
)

const tops = 5

var (
verbose = flag.Bool("v", false, "verbose mode")
storeFile = flag.String("c", "corpus.gob", "the file to store corpora")
)

func main() {
flag.Parse()

store, err := storage.NewSeparatedMemoryStorage(*storeFile)
if err != nil {
log.Fatal(err)
}

chatbot := &bot.ChatBot{
LogicAdapter: logic.NewClosestMatch(store, tops),
}
if *verbose {
chatbot.LogicAdapter.SetVerbose()
}

scanner := bufio.NewScanner(os.Stdin)
for {
fmt.Print("Q: ")
scanner.Scan()
question := scanner.Text()
if question == "exit" {
break
}

startTime := time.Now()
answers := chatbot.GetResponse(question)
for i, answer := range answers {
fmt.Printf("%d: %s\n", i+1, answer.Content)
if *verbose {
fmt.Printf("%d: %s\tConfidence: %.3f\t%s\n", i+1, answer.Content,
answer.Confidence, time.Since(startTime))
}
}
fmt.Println(time.Since(startTime))
}
}
39 changes: 39 additions & 0 deletions cli/train/train.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"flag"
"log"
"strings"

"github.com/kevwan/chatbot/bot"
"github.com/kevwan/chatbot/bot/adapters/storage"
)

var (
corpora = flag.String("i", "", "the corpora files, comma to separate multiple files")
storeFile = flag.String("o", "corpus.gob", "the file to store corpora")
printMemStats = flag.Bool("m", false, "enable printing memory stats")
)

func main() {
flag.Parse()

if len(*corpora) == 0 {
flag.Usage()
return
}

store, err := storage.NewSeparatedMemoryStorage(*storeFile)
if err != nil {
log.Fatal(err)
}

chatbot := &bot.ChatBot{
PrintMemStats: *printMemStats,
Trainer: bot.NewCorpusTrainer(store),
StorageAdapter: store,
}
if err := chatbot.Train(strings.Split(*corpora, ",")); err != nil {
log.Fatal(err)
}
}
10 changes: 1 addition & 9 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,16 +1,8 @@
module github.com/kevwan/chatbot

go 1.17
go 1.15

require (
github.com/tal-tech/go-zero v1.2.1
github.com/wangbin/jiebago v0.3.2
)

require (
github.com/spaolacci/murmur3 v1.1.0 // indirect
go.opentelemetry.io/otel v1.0.0-RC2 // indirect
go.opentelemetry.io/otel/trace v1.0.0-RC2 // indirect
go.uber.org/automaxprocs v1.3.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
Expand Down

0 comments on commit f697da8

Please sign in to comment.