diff --git a/.gitignore b/.gitignore index 0b4e646..8723181 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,10 @@ _test *.[568vq] [568vq].out +# IntelliJ +*.iml +.idea + *.cgo1.go *.cgo2.c _cgo_defun.c @@ -24,5 +28,6 @@ _testmain.go *.exe *.test *.prof +*.sw[nop] mct.go diff --git a/README.md b/README.md index 24b673a..b10ff34 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ func main() { // Let's get some context about these images urls := []string{"http://www.clarifai.com/img/metro-north.jpg", "http://www.clarifai.com/img/metro-north.jpg"} // Give it to Clarifai to run their magic - tag_data, err := client.Tag(urls, nil) + tag_data, err := client.Tag(clarifai.TagRequest{URLs: urls}) if err != nil { fmt.Println(err) diff --git a/client.go b/client.go index 03a6627..59a5f4c 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package clarifai import ( + "bytes" "encoding/json" "errors" "io/ioutil" @@ -82,20 +83,26 @@ func (client *Client) requestAccessToken() error { return nil } -func (client *Client) commonHTTPRequest(values url.Values, endpoint, verb string, retry bool) ([]byte, error) { - if values == nil { - values = url.Values{} +func (client *Client) commonHTTPRequest(jsonBody interface{}, endpoint, verb string, retry bool) ([]byte, error) { + if jsonBody == nil { + jsonBody = struct{}{} } - req, err := http.NewRequest(verb, client.buildURL(endpoint), strings.NewReader(values.Encode())) + body, err := json.Marshal(jsonBody) if err != nil { return nil, err } - req.Header.Set("Content-Length", strconv.Itoa(len(values.Encode()))) + req, err := http.NewRequest(verb, client.buildURL(endpoint), bytes.NewReader(body)) + + if err != nil { + return nil, err + } + + req.Header.Set("Content-Length", strconv.Itoa(len(body))) req.Header.Set("Authorization", "Bearer "+client.AccessToken) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Content-Type", "application/json") httpClient := &http.Client{} res, err := httpClient.Do(req) @@ -118,7 +125,7 @@ func (client *Client) commonHTTPRequest(values url.Values, endpoint, verb string if err != nil { return nil, err } - return client.commonHTTPRequest(values, endpoint, verb, true) + return client.commonHTTPRequest(jsonBody, endpoint, verb, true) } return nil, errors.New("TOKEN_INVALID") case 429: diff --git a/requests.go b/requests.go index 1ff9c65..9ad8f07 100644 --- a/requests.go +++ b/requests.go @@ -3,8 +3,6 @@ package clarifai import ( "encoding/json" "errors" - "net/url" - "strings" ) // InfoResp represents the expected JSON response from /info/ @@ -27,6 +25,13 @@ type InfoResp struct { } } +// TagRequest represents a JSON request for /tag/ +type TagRequest struct { + URLs []string `json:"url"` + LocalIDs []string `json:"local_ids,omitempty"` + Model string `json:"model,omitempty"` +} + // TagResp represents the expected JSON response from /tag/ type TagResp struct { StatusCode string `json:"status_code"` @@ -60,13 +65,13 @@ type TagResult struct { // FeedbackForm is used to send feedback back to Clarifai type FeedbackForm struct { - DocIDs []string - URLs []string - AddTags []string - RemoveTags []string - DissimilarDocIDs []string - SimilarDocIDs []string - SearchClick []string + DocIDs []string `json:"docids,omitempty"` + URLs []string `json:"url,omitempty"` + AddTags []string `json:"add_tags,omitempty"` + RemoveTags []string `json:"remove_tags,omitempty"` + DissimilarDocIDs []string `json:"dissimilar_docids,omitempty"` + SimilarDocIDs []string `json:"similar_docids,omitempty"` + SearchClick []string `json:"search_click,omitempty"` } // FeedbackResp is the expected response from /feedback/ @@ -90,22 +95,12 @@ func (client *Client) Info() (*InfoResp, error) { } // Tag allows the client to request tag data on a single, or multiple photos -func (client *Client) Tag(urls, localIDs []string) (*TagResp, error) { - if urls == nil { +func (client *Client) Tag(req TagRequest) (*TagResp, error) { + if len(req.URLs) < 1 { return nil, errors.New("Requires at least one url") } - form := url.Values{} - for _, url := range urls { - form.Add("url", url) - } - if localIDs != nil { - for _, localID := range localIDs { - form.Add("local_id", localID) - } - } - - res, err := client.commonHTTPRequest(form, "tag", "POST", false) + res, err := client.commonHTTPRequest(req, "tag", "POST", false) if err != nil { return nil, err @@ -118,43 +113,15 @@ func (client *Client) Tag(urls, localIDs []string) (*TagResp, error) { } // Feedback allows the user to provide contextual feedback to Clarifai in order to improve their results -func (client *Client) Feedback(params FeedbackForm) (*FeedbackResp, error) { - if params.DocIDs == nil && params.URLs == nil { +func (client *Client) Feedback(form FeedbackForm) (*FeedbackResp, error) { + if form.DocIDs == nil && form.URLs == nil { return nil, errors.New("Requires at least one docid or url") } - if params.DocIDs != nil && params.URLs != nil { + if form.DocIDs != nil && form.URLs != nil { return nil, errors.New("Request must provide exactly one of the following fields: {'DocIDs', 'URLs'}") } - form := url.Values{} - - if params.DocIDs != nil { - form.Add("docids", strings.Join(params.DocIDs, ",")) - } else { - form.Add("url", strings.Join(params.URLs, ",")) - } - - if params.AddTags != nil { - form.Add("add_tags", strings.Join(params.AddTags, ",")) - } - - if params.RemoveTags != nil { - form.Add("remove_tags", strings.Join(params.RemoveTags, ",")) - } - - if params.DissimilarDocIDs != nil { - form.Add("dissimilar_docids", strings.Join(params.DissimilarDocIDs, ",")) - } - - if params.SimilarDocIDs != nil { - form.Add("similar_docids", strings.Join(params.SimilarDocIDs, ",")) - } - - if params.SearchClick != nil { - form.Add("search_click", strings.Join(params.SearchClick, ",")) - } - res, err := client.commonHTTPRequest(form, "feedback", "POST", false) feedbackres := new(FeedbackResp) diff --git a/requests_test.go b/requests_test.go index 109f0ec..53b5dd2 100644 --- a/requests_test.go +++ b/requests_test.go @@ -55,7 +55,7 @@ func TestTagMultiple(t *testing.T) { }) urls := []string{"http://www.clarifai.com/img/metro-north.jpg", "http://www.clarifai.com/img/metro-north.jpg"} - _, err := client.Tag(urls, nil) + _, err := client.Tag(TagRequest{URLs: urls}) if err != nil { t.Errorf("Tag() should not return error with valid request: %q\n", err)