diff --git a/gorequest.go b/gorequest.go index ea1a55c..635b451 100644 --- a/gorequest.go +++ b/gorequest.go @@ -947,11 +947,15 @@ func (s *SuperAgent) EndBytes(callback ...func(response Response, body []byte, e for { resp, body, errs = s.getResponseBytes() if errs != nil { - return nil, nil, errs - } - if s.isRetryableRequest(resp) { - resp.Header.Set("Retry-Count", strconv.Itoa(s.Retryable.Attempt)) - break + if s.isErrNotRetryableRequest(errs) { + errs = append(errs, fmt.Errorf("retry attempt: %d", s.Retryable.Attempt)) + return nil, nil, errs + } + } else { + if s.isRetryableRequest(resp, errs) { + resp.Header.Set("Retry-Count", strconv.Itoa(s.Retryable.Attempt)) + break + } } } @@ -962,8 +966,8 @@ func (s *SuperAgent) EndBytes(callback ...func(response Response, body []byte, e return resp, body, nil } -func (s *SuperAgent) isRetryableRequest(resp Response) bool { - if s.Retryable.Enable && s.Retryable.Attempt < s.Retryable.RetryerCount && contains(resp.StatusCode, s.Retryable.RetryableStatus) { +func (s *SuperAgent) isRetryableRequest(resp Response, errs []error) bool { + if s.Retryable.Enable && s.Retryable.Attempt < s.Retryable.RetryerCount && (contains(resp.StatusCode, s.Retryable.RetryableStatus)) { time.Sleep(s.Retryable.RetryerTime) s.Retryable.Attempt++ return false @@ -971,6 +975,16 @@ func (s *SuperAgent) isRetryableRequest(resp Response) bool { return true } +func (s *SuperAgent) isErrNotRetryableRequest(errs []error) bool { + if s.Retryable.Enable && s.Retryable.Attempt < s.Retryable.RetryerCount && IsTimeout(errs) { + time.Sleep(s.Retryable.RetryerTime) + s.Retryable.Attempt++ + s.Errors = nil + return false + } + return true +} + func contains(respStatus int, statuses []int) bool { for _, status := range statuses { if status == respStatus { @@ -980,6 +994,18 @@ func contains(respStatus int, statuses []int) bool { return false } +// IsTimeout checks is it a net timeout error. +// Applicable only for go version >1.6 +func IsTimeout(errs []error) bool { + for _, v := range errs { + if err, ok := v.(net.Error); ok && err.Timeout() { + return true + } + } + + return false +} + // EndStruct should be used when you want the body as a struct. The callbacks work the same way as with `End`, except that a struct is used instead of a string. func (s *SuperAgent) EndStruct(v interface{}, callback ...func(response Response, v interface{}, body []byte, errs []error)) (Response, []byte, []error) { resp, body, errs := s.EndBytes() diff --git a/gorequest_test.go b/gorequest_test.go index 2cf5a74..83df44d 100644 --- a/gorequest_test.go +++ b/gorequest_test.go @@ -203,9 +203,12 @@ func TestGet(t *testing.T) { // testing for Get method with retry option func TestRetryGet(t *testing.T) { const ( - case1_empty = "/" - case24_after_3_attempt_return_valid = "/retry_3_attempt_then_valid" - retry_count_expected = "3" + case1_empty = "/" + case24_after_3_attempt_return_valid = "/retry_3_attempt_then_valid" + retry_count_expected = "3" + casetimeout_after_3_attempt_return_valid = "/timeout_after_3_attempt_return_valid" + timeout_retry_expected = "retry attempt: 3" + casetimeout_retry_3_attempt = "/imeout_retry_3_attempt" ) var attempt int @@ -235,6 +238,16 @@ func TestRetryGet(t *testing.T) { t.Logf("case %v ", case24_after_3_attempt_return_valid) } attempt++ + case casetimeout_retry_3_attempt: + time.Sleep(5 * time.Nanosecond) + attempt++ + case casetimeout_after_3_attempt_return_valid: + if attempt == 4 { + w.WriteHeader(200) + } else { + time.Sleep(2 * time.Second) + } + attempt++ } })) @@ -264,6 +277,33 @@ func TestRetryGet(t *testing.T) { if retryCountReturn != retry_count_expected { t.Errorf("Expected [%s] retry but was [%s]", retry_count_expected, retryCountReturn) } + + // Timeout retry 3 times + _, _, errs = New().Get(ts.URL+casetimeout_retry_3_attempt). + Timeout(1*time.Nanosecond). + Retry(3, 1*time.Nanosecond, http.StatusBadRequest). + End() + if errs != nil { + lastErr := errs[len(errs)-1] + if lastErr.Error() != timeout_retry_expected { + t.Errorf("Expected [%s] retry but was [%s]", timeout_retry_expected, lastErr) + } + } else { + t.Errorf("No testing for this case yet : %q", errs) + } + + // Timeout, after 3 attempt valid + resp, _, errs = New().Get(ts.URL+casetimeout_after_3_attempt_return_valid). + Timeout(1*time.Second). + Retry(3, 1*time.Second). + End() + if errs != nil { + t.Errorf("No testing for this case yet : %v", errs) + } else { + if resp.StatusCode != 200 { + t.Errorf("Expected [%d] but was [%d]", resp.StatusCode, http.StatusOK) + } + } } // testing for Options method