Skip to content

Commit

Permalink
file: peek at contents to determine content type and packing level
Browse files Browse the repository at this point in the history
  • Loading branch information
adamdecaf committed Dec 21, 2023
1 parent 7bc96de commit 517d499
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 31 deletions.
61 changes: 40 additions & 21 deletions pkg/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,44 +73,63 @@ func NewFileFromReader(r io.Reader) (File, error) {
return nil, errors.New("invalid file reader")
}

dummy := &dummyFile{}

decoder := json.NewDecoder(r)
jsonDecodeErr := decoder.Decode(dummy)
// Take a peek and see if we encounter '{' (would imply the contents is JSON)
preview := make([]byte, 1024)
n, err := io.ReadFull(r, preview)
switch {
case err == io.ErrUnexpectedEOF:
preview = preview[:n]
r = bytes.NewReader(preview)
case err != nil:
return nil, err
default:
r = io.MultiReader(bytes.NewReader(preview), r)
}

// reset file reader
// need to read first block to detect json or metro format
// after that, need to reset seek point of reader
if sk, ok := r.(io.Seeker); ok {
sk.Seek(0, io.SeekStart)
// Look for the start of JSON
var isJSON bool
for i := range preview {
if preview[i] == '{' {
isJSON = true
break
}
}

if jsonDecodeErr != nil {
// Parse metro file
// Decode contents as Metro2 formatting when it's not JSON
if !isJSON {
return NewReader(r).Read()
}

// Parse json file
if dummy.Header == nil {
return nil, errors.New("invalid json file")
// Determine the file format
var buf bytes.Buffer
r = io.TeeReader(r, &buf)

var dummy dummyFile
err = json.NewDecoder(r).Decode(&dummy)
if err != nil {
return nil, fmt.Errorf("reading header: %w", err)
}

fileFormat := utils.CharacterFileFormat
if dummy.Header.RecordDescriptorWord == lib.UnpackedRecordLength {
fileFormat = utils.CharacterFileFormat
} else if dummy.Header.BlockDescriptorWord > 0 {
fileFormat = utils.PackedFileFormat
if dummy.Header != nil {
if dummy.Header.RecordDescriptorWord == lib.UnpackedRecordLength {
fileFormat = utils.CharacterFileFormat
} else if dummy.Header.BlockDescriptorWord > 0 {
fileFormat = utils.PackedFileFormat
}
}

// Decode the file as JSON now
f, err := NewFile(fileFormat)
if err != nil {
return nil, err
}

if err = decoder.Decode(f); err != nil {
return nil, err
r = io.MultiReader(&buf, r)
err = json.NewDecoder(r).Decode(f)
if err != nil {
return f, fmt.Errorf("reading file: %w", err)
}

return f, nil
}

Expand Down
17 changes: 11 additions & 6 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,24 @@ import (
)

func parseInputFromRequest(r *http.Request) (file.File, error) {
src, _, err := r.FormFile("file")
if err != nil {
mf, err := file.NewFileFromReader(r.Body)
contentType := strings.ToLower(r.Header.Get("Content-Type"))
if strings.HasPrefix(contentType, "multipart/") {
src, _, err := r.FormFile("file")
if err != nil {
return nil, err
return nil, fmt.Errorf("reading multipart request: %w", err)
}
defer src.Close()

mf, err := file.NewFileFromReader(src)
if err != nil {
return nil, fmt.Errorf("parsing file as multipart: %w", err)
}
return mf, nil
}

mf, err := file.NewFileFromReader(src)
mf, err := file.NewFileFromReader(r.Body)
if err != nil {
return nil, err
return nil, fmt.Errorf("parsing request body as reader: %w", err)
}
return mf, nil
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/server/suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (t *ServerTest) TestWithInvalidForm(c *check.C) {
c.Assert(recorder.Code, check.Equals, http.StatusBadRequest)
}

func (t *ServerTest) TestPrintWithInvalidData(c *check.C) {
func (t *ServerTest) TestPrintWithoutContentType(c *check.C) {
writer, body := t.getWriter("base_segment.json", c)
err := writer.WriteField("format", "json")
c.Assert(err, check.IsNil)
Expand All @@ -227,10 +227,10 @@ func (t *ServerTest) TestPrintWithInvalidData(c *check.C) {
recorder, request := t.makeRequest(http.MethodPost, "/print", body.String(), c)
request.Header.Set("Content-Type", writer.FormDataContentType())
t.testServer.ServeHTTP(recorder, request)
c.Assert(recorder.Code, check.Equals, http.StatusBadRequest)
c.Assert(recorder.Code, check.Equals, http.StatusOK)
}

func (t *ServerTest) TestConvertWithInvalidData(c *check.C) {
func (t *ServerTest) TestConvertWithoutContentType(c *check.C) {
writer, body := t.getWriter("base_segment.json", c)
err := writer.WriteField("format", "json")
c.Assert(err, check.IsNil)
Expand All @@ -239,7 +239,7 @@ func (t *ServerTest) TestConvertWithInvalidData(c *check.C) {
recorder, request := t.makeRequest(http.MethodPost, "/convert", body.String(), c)
request.Header.Set("Content-Type", writer.FormDataContentType())
t.testServer.ServeHTTP(recorder, request)
c.Assert(recorder.Code, check.Equals, http.StatusBadRequest)
c.Assert(recorder.Code, check.Equals, http.StatusOK)
}

func (t *ServerTest) TestConvertWithValidJsonRequest(c *check.C) {
Expand Down

0 comments on commit 517d499

Please sign in to comment.