diff --git a/pkg/file/file.go b/pkg/file/file.go index 6c8fdd1..88f583b 100644 --- a/pkg/file/file.go +++ b/pkg/file/file.go @@ -73,44 +73,63 @@ func NewFileFromReader(r io.Reader) (File, error) { return nil, errors.New("invalid file reader") } - dummy := &dummyFile{} - - decoder := json.NewDecoder(r) - jsonDecodeErr := decoder.Decode(dummy) + // Take a peek and see if we encounter '{' (would imply the contents is JSON) + preview := make([]byte, 1024) + n, err := io.ReadFull(r, preview) + switch { + case err == io.ErrUnexpectedEOF: + preview = preview[:n] + r = bytes.NewReader(preview) + case err != nil: + return nil, err + default: + r = io.MultiReader(bytes.NewReader(preview), r) + } - // reset file reader - // need to read first block to detect json or metro format - // after that, need to reset seek point of reader - if sk, ok := r.(io.Seeker); ok { - sk.Seek(0, io.SeekStart) + // Look for the start of JSON + var isJSON bool + for i := range preview { + if preview[i] == '{' { + isJSON = true + break + } } - if jsonDecodeErr != nil { - // Parse metro file + // Decode contents as Metro2 formatting when it's not JSON + if !isJSON { return NewReader(r).Read() } - // Parse json file - if dummy.Header == nil { - return nil, errors.New("invalid json file") + // Determine the file format + var buf bytes.Buffer + r = io.TeeReader(r, &buf) + + var dummy dummyFile + err = json.NewDecoder(r).Decode(&dummy) + if err != nil { + return nil, fmt.Errorf("reading header: %w", err) } fileFormat := utils.CharacterFileFormat - if dummy.Header.RecordDescriptorWord == lib.UnpackedRecordLength { - fileFormat = utils.CharacterFileFormat - } else if dummy.Header.BlockDescriptorWord > 0 { - fileFormat = utils.PackedFileFormat + if dummy.Header != nil { + if dummy.Header.RecordDescriptorWord == lib.UnpackedRecordLength { + fileFormat = utils.CharacterFileFormat + } else if dummy.Header.BlockDescriptorWord > 0 { + fileFormat = utils.PackedFileFormat + } } + // Decode the file as JSON now f, err := NewFile(fileFormat) if err != nil { return nil, err } - if err = decoder.Decode(f); err != nil { - return nil, err + r = io.MultiReader(&buf, r) + err = json.NewDecoder(r).Decode(f) + if err != nil { + return f, fmt.Errorf("reading file: %w", err) } - return f, nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index a374622..d50182b 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -16,19 +16,24 @@ import ( ) func parseInputFromRequest(r *http.Request) (file.File, error) { - src, _, err := r.FormFile("file") - if err != nil { - mf, err := file.NewFileFromReader(r.Body) + contentType := strings.ToLower(r.Header.Get("Content-Type")) + if strings.HasPrefix(contentType, "multipart/") { + src, _, err := r.FormFile("file") if err != nil { - return nil, err + return nil, fmt.Errorf("reading multipart request: %w", err) } + defer src.Close() + mf, err := file.NewFileFromReader(src) + if err != nil { + return nil, fmt.Errorf("parsing file as multipart: %w", err) + } return mf, nil } - mf, err := file.NewFileFromReader(src) + mf, err := file.NewFileFromReader(r.Body) if err != nil { - return nil, err + return nil, fmt.Errorf("parsing request body as reader: %w", err) } return mf, nil } diff --git a/pkg/server/suite_test.go b/pkg/server/suite_test.go index 6863847..6249022 100644 --- a/pkg/server/suite_test.go +++ b/pkg/server/suite_test.go @@ -218,7 +218,7 @@ func (t *ServerTest) TestWithInvalidForm(c *check.C) { c.Assert(recorder.Code, check.Equals, http.StatusBadRequest) } -func (t *ServerTest) TestPrintWithInvalidData(c *check.C) { +func (t *ServerTest) TestPrintWithoutContentType(c *check.C) { writer, body := t.getWriter("base_segment.json", c) err := writer.WriteField("format", "json") c.Assert(err, check.IsNil) @@ -227,10 +227,10 @@ func (t *ServerTest) TestPrintWithInvalidData(c *check.C) { recorder, request := t.makeRequest(http.MethodPost, "/print", body.String(), c) request.Header.Set("Content-Type", writer.FormDataContentType()) t.testServer.ServeHTTP(recorder, request) - c.Assert(recorder.Code, check.Equals, http.StatusBadRequest) + c.Assert(recorder.Code, check.Equals, http.StatusOK) } -func (t *ServerTest) TestConvertWithInvalidData(c *check.C) { +func (t *ServerTest) TestConvertWithoutContentType(c *check.C) { writer, body := t.getWriter("base_segment.json", c) err := writer.WriteField("format", "json") c.Assert(err, check.IsNil) @@ -239,7 +239,7 @@ func (t *ServerTest) TestConvertWithInvalidData(c *check.C) { recorder, request := t.makeRequest(http.MethodPost, "/convert", body.String(), c) request.Header.Set("Content-Type", writer.FormDataContentType()) t.testServer.ServeHTTP(recorder, request) - c.Assert(recorder.Code, check.Equals, http.StatusBadRequest) + c.Assert(recorder.Code, check.Equals, http.StatusOK) } func (t *ServerTest) TestConvertWithValidJsonRequest(c *check.C) {