diff --git a/internal/pkg/crawl/capture.go b/internal/pkg/crawl/capture.go
index 0d3be9c4..0e7308c8 100644
--- a/internal/pkg/crawl/capture.go
+++ b/internal/pkg/crawl/capture.go
@@ -467,10 +467,11 @@ func (c *Crawl) Capture(item *queue.Item) error {
outlinks = append(outlinks, URLsFromS3...)
} else {
- URLsFromXML, isSitemap, err := extractor.XML(resp)
+ URLsFromXML, isSitemap, err := extractor.XML(resp, false)
if err != nil {
c.Log.WithFields(c.genLogFields(err, item.URL, nil)).Error("unable to extract URLs from XML")
- } else {
+ }
+ if len(URLsFromXML) > 0 {
if isSitemap {
outlinks = append(outlinks, URLsFromXML...)
} else {
diff --git a/internal/pkg/crawl/extractor/xml.go b/internal/pkg/crawl/extractor/xml.go
index e2300317..2e8e60df 100644
--- a/internal/pkg/crawl/extractor/xml.go
+++ b/internal/pkg/crawl/extractor/xml.go
@@ -9,36 +9,37 @@ import (
"strings"
)
-func XML(resp *http.Response) (URLs []*url.URL, sitemap bool, err error) {
+var sitemapMarker = []byte("sitemaps.org/schemas/sitemap/")
+
+func XML(resp *http.Response, strict bool) (URLs []*url.URL, sitemap bool, err error) {
xmlBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, sitemap, err
}
- if strings.Contains(string(xmlBody), "sitemaps.org/schemas/sitemap/") {
+ if bytes.Contains(xmlBody, sitemapMarker) {
sitemap = true
}
- reader := bytes.NewReader(xmlBody)
- decoder := xml.NewDecoder(reader)
-
- // try to decode one token to see if stream is open
- _, err = decoder.Token()
- if err != nil {
- return nil, sitemap, err
- }
-
- // seek back to 0 if we are still here
- reader.Seek(0, 0)
- decoder = xml.NewDecoder(reader)
+ decoder := xml.NewDecoder(bytes.NewReader(xmlBody))
+ decoder.Strict = strict
+ var tok xml.Token
for {
- tok, err := decoder.Token()
- if err == io.EOF {
+ if strict {
+ tok, err = decoder.Token()
+ } else {
+ tok, err = decoder.RawToken()
+ }
+
+ if tok == nil && err == io.EOF {
+ // normal EOF: the token stream is exhausted
break
}
+
if err != nil {
- return nil, sitemap, err
+ // return the URLs collected so far when an error occurs
+ return URLs, sitemap, err
}
switch tok := tok.(type) {
@@ -52,7 +53,7 @@ func XML(resp *http.Response) (URLs []*url.URL, sitemap bool, err error) {
}
}
case xml.CharData:
- if strings.HasPrefix(string(tok), "http") {
+ if bytes.HasPrefix(tok, []byte("http")) {
parsedURL, err := url.Parse(string(tok))
if err == nil {
URLs = append(URLs, parsedURL)
diff --git a/internal/pkg/crawl/extractor/xml_test.go b/internal/pkg/crawl/extractor/xml_test.go
index 3dbb1cfa..5c975c86 100644
--- a/internal/pkg/crawl/extractor/xml_test.go
+++ b/internal/pkg/crawl/extractor/xml_test.go
@@ -2,21 +2,26 @@ package extractor
import (
"bytes"
+ "encoding/xml"
"io"
"net/http"
"net/url"
"os"
+ "strings"
"testing"
)
func TestXML(t *testing.T) {
tests := []struct {
- name string
- xmlBody string
- wantURLs []*url.URL
- wantURLsCount int
- wantErr bool
- sitemap bool
+ name string
+ xmlBody string
+ wantURLsLax []*url.URL
+ wantURLsStric []*url.URL
+ wantURLsCountLax int
+ wantURLsCountStric int
+ wantErrLax bool
+ wantErrStrict bool
+ sitemap bool
}{
{
name: "Valid XML with URLs",
@@ -28,26 +33,49 @@ func TestXML(t *testing.T) {
just some text
`,
- wantURLs: []*url.URL{
+ wantURLsLax: []*url.URL{
+ {Scheme: "http", Host: "example.com"},
+ {Scheme: "https", Host: "example.org"},
+ },
+ wantURLsStric: []*url.URL{
{Scheme: "http", Host: "example.com"},
{Scheme: "https", Host: "example.org"},
},
sitemap: false,
- wantErr: false,
},
{
- name: "Empty XML",
- xmlBody: ``,
- wantURLs: nil,
- wantErr: false,
- sitemap: false,
+ name: "unbalanced XML with URLs",
+ xmlBody: `
+
+ http://example.com
+
+ https://unclosed.example.com`,
+ wantURLsStric: []*url.URL{
+ {Scheme: "http", Host: "example.com"},
+ },
+ wantURLsLax: []*url.URL{
+ {Scheme: "http", Host: "example.com"},
+ {Scheme: "https", Host: "unclosed.example.com"},
+ },
+ wantErrStrict: true,
+ wantErrLax: false,
+ sitemap: false,
+ },
+ {
+ name: "Empty XML",
+ xmlBody: ``,
+ wantURLsStric: nil,
+ wantURLsLax: nil,
+ sitemap: false,
},
{
- name: "Invalid XML",
- xmlBody: ``,
- wantURLs: nil,
- wantErr: true,
- sitemap: false,
+ name: "alien XML",
+ xmlBody: `/>/,as;g^&R$W#Sf)(U>http://example.com
- not a valid url
`,
- wantURLs: []*url.URL{
+ wantURLsStric: []*url.URL{
{Scheme: "http", Host: "example.com"},
},
- wantErr: false,
- sitemap: false,
+ wantURLsLax: []*url.URL{
+ {Scheme: "http", Host: "example.com"},
+ },
+ wantErrStrict: false,
+ wantErrLax: false,
+ sitemap: false,
},
{
- name: "Huge sitemap",
- xmlBody: loadTestFile(t, "xml_test_sitemap.xml"),
- wantURLsCount: 100002,
- wantErr: false,
- sitemap: true,
+ name: "Huge sitemap",
+ xmlBody: loadTestFile(t, "xml_test_sitemap.xml"),
+ wantURLsCountStric: 100002,
+ wantURLsCountLax: 100002,
+ wantErrStrict: false,
+ wantErrLax: false,
+ sitemap: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- resp := &http.Response{
- Body: io.NopCloser(bytes.NewBufferString(tt.xmlBody)),
- }
-
- gotURLs, sitemap, err := XML(resp)
- if (err != nil) != tt.wantErr {
- t.Errorf("XML() error = %v, wantErr %v", err, tt.wantErr)
- return
- }
-
- if tt.wantURLsCount != 0 {
- if len(gotURLs) != tt.wantURLsCount {
- t.Errorf("XML() gotURLs count = %v, want %v", len(gotURLs), tt.wantURLsCount)
+ testMode := func(strict bool, wantErr bool, wantURLs []*url.URL, wantURLsCount int) {
+ resp := &http.Response{
+ Body: io.NopCloser(bytes.NewBufferString(tt.xmlBody)),
}
- }
-
- if tt.wantURLs != nil {
- if !compareURLs(gotURLs, tt.wantURLs) {
- t.Errorf("XML() gotURLs = %v, want %v", gotURLs, tt.wantURLs)
+ gotURLs, sitemap, err := XML(resp, strict)
+ if (err != nil) != wantErr {
+ t.Errorf("XML() strict = %v, error = %v, wantErr %v", strict, err, wantErr)
+ return
+ }
+ if wantURLsCount != 0 && len(gotURLs) != wantURLsCount {
+ t.Errorf("XML() strict = %v, gotURLs count = %v, want %v", strict, len(gotURLs), wantURLsCount)
+ }
+ if wantURLs != nil && !compareURLs(gotURLs, wantURLs) {
+ t.Errorf("XML() strict = %v, gotURLs = %v, want %v", strict, gotURLs, wantURLs)
+ }
+ if tt.sitemap != sitemap {
+ t.Errorf("XML() strict = %v, sitemap = %v, want %v", strict, sitemap, tt.sitemap)
}
}
- if tt.sitemap != sitemap {
- t.Errorf("XML() sitemap = %v, want %v", sitemap, tt.sitemap)
- }
+ // Strict mode
+ testMode(true, tt.wantErrStrict, tt.wantURLsStric, tt.wantURLsCountStric)
+
+ // Lax mode
+ testMode(false, tt.wantErrLax, tt.wantURLsLax, tt.wantURLsCountLax)
})
}
}
@@ -116,14 +150,32 @@ func loadTestFile(t *testing.T, path string) string {
return string(b)
}
-func TestXMLBodyReadError(t *testing.T) {
+func TestXMLBodySyntaxEOFErrorStrict(t *testing.T) {
+ wantErr := xml.SyntaxError{Line: 3, Msg: "unexpected EOF"}
resp := &http.Response{
- Body: io.NopCloser(bytes.NewReader([]byte{})), // Empty reader to simulate EOF
+ Body: io.NopCloser(strings.NewReader(
+ `
+
+ `)),
}
- resp.Body.Close() // Close the body to simulate a read error
-
- _, _, err := XML(resp)
+ _, _, err := XML(resp, true)
if err == nil {
- t.Errorf("XML() expected error, got nil")
+ t.Errorf("XML() error = %v, wantErr %v", err, wantErr)
+ return
+ }
+ if err.Error() != wantErr.Error() {
+ t.Errorf("XML() error = %v, wantErr %v", err, wantErr)
+ }
+}
+
+func TestXMLBodySyntaxEOFErrorLax(t *testing.T) {
+ resp := &http.Response{
+ Body: io.NopCloser(strings.NewReader(`
+
+ `)),
+ }
+ _, _, err := XML(resp, false)
+ if err != nil {
+ t.Errorf("XML() error = %v, wantErr nil", err)
}
}