Path: blob/main/vendor/github.com/onsi/gomega/internal/async_assertion.go
2880 views
package internal

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"runtime"
	"sync"
	"time"

	"github.com/onsi/gomega/format"
	"github.com/onsi/gomega/types"
)

// Reflected interface types used to inspect the signature of polled functions
// and the types of forwarded arguments.
var errInterface = reflect.TypeOf((*error)(nil)).Elem()
var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem()
var contextType = reflect.TypeOf(new(context.Context)).Elem()

// formattedGomegaError is implemented by errors that already carry a fully
// formatted failure message; the failure-message generator in match prints
// that message verbatim instead of rendering the error object.
type formattedGomegaError interface {
	FormattedGomegaError() string
}

// asyncPolledActualError is the error produced when polling the actual value
// goes wrong in a way Gomega itself detected (no return values, unexpected
// non-zero extra return values, or a failed assertion inside the polled
// function). Its message is used both as Error() and as the preformatted
// failure text.
type asyncPolledActualError struct {
	message string
}

func (err *asyncPolledActualError) Error() string {
	return err.message
}

func (err *asyncPolledActualError) FormattedGomegaError() string {
	return err.message
}

// contextWithAttachProgressReporter is looked up on the assertion's context
// (under the "GINKGO_SPEC_CONTEXT" key) so the failure-message generator can be
// registered as a progress reporter; the returned func detaches it.
type contextWithAttachProgressReporter interface {
	AttachProgressReporter(func() string) func()
}

// asyncGomegaHaltExecutionError is the panic value raised by the Gomega
// instance handed to a polled function when one of its assertions fails. The
// panic unwinds the polled function; the poller recovers it and reports the
// recorded failure instead.
type asyncGomegaHaltExecutionError struct{}

// GinkgoRecoverShouldIgnoreThisPanic is a marker method; presumably Ginkgo's
// GinkgoRecover checks for it to avoid reporting this panic twice — confirm
// against Ginkgo.
func (a asyncGomegaHaltExecutionError) GinkgoRecoverShouldIgnoreThisPanic() {}
func (a asyncGomegaHaltExecutionError) Error() string {
	return `An assertion has failed in a goroutine. You should call

defer GinkgoRecover()

at the top of the goroutine that caused this panic. This will allow Ginkgo and Gomega to correctly capture and manage this panic.`
}

// AsyncAssertionType distinguishes Eventually assertions from Consistently
// assertions.
type AsyncAssertionType uint

const (
	// AsyncAssertionTypeEventually succeeds as soon as the matcher is satisfied
	// (MustPassRepeatedly times in a row).
	AsyncAssertionTypeEventually AsyncAssertionType = iota
	// AsyncAssertionTypeConsistently fails as soon as the matcher is not
	// satisfied before the timeout elapses.
	AsyncAssertionTypeConsistently
)

// String returns "Eventually" or "Consistently", used in failure messages.
func (at AsyncAssertionType) String() string {
	switch at {
	case AsyncAssertionTypeEventually:
		return "Eventually"
	case AsyncAssertionTypeConsistently:
		return "Consistently"
	}
	return "INVALID ASYNC ASSERTION TYPE"
}

// AsyncAssertion holds the configuration of a single Eventually/Consistently
// assertion. Negative timeoutInterval/pollingInterval values mean "use the
// defaults from the Gomega DurationBundle" (see afterTimeout/afterPolling).
type AsyncAssertion struct {
	asyncType AsyncAssertionType

	actualIsFunc  bool
	actual        any
	argsToForward []any

	timeoutInterval    time.Duration
	pollingInterval    time.Duration
	mustPassRepeatedly int
	ctx                context.Context
	offset             int
	g                  *Gomega
}

// NewAsyncAssertion builds an AsyncAssertion. When actualInput is a function it
// is invoked on every poll; any other value is matched as-is.
func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput any, g *Gomega, timeoutInterval time.Duration, pollingInterval time.Duration, mustPassRepeatedly int, ctx context.Context, offset int) *AsyncAssertion {
	out := &AsyncAssertion{
		asyncType:          asyncType,
		timeoutInterval:    timeoutInterval,
		pollingInterval:    pollingInterval,
		mustPassRepeatedly: mustPassRepeatedly,
		offset:             offset,
		ctx:                ctx,
		g:                  g,
	}

	out.actual = actualInput
	if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func {
		out.actualIsFunc = true
	}

	return out
}

// WithOffset sets the extra call-stack offset used when reporting failures.
// Like every With*/configuration method below, it mutates the receiver and
// returns it to allow chaining.
func (assertion *AsyncAssertion) WithOffset(offset int) types.AsyncAssertion {
	assertion.offset = offset
	return assertion
}

// WithTimeout sets how long the assertion keeps polling.
func (assertion *AsyncAssertion) WithTimeout(interval time.Duration) types.AsyncAssertion {
	assertion.timeoutInterval = interval
	return assertion
}

// WithPolling sets the delay between polls.
func (assertion *AsyncAssertion) WithPolling(interval time.Duration) types.AsyncAssertion {
	assertion.pollingInterval = interval
	return assertion
}

// Within is an alias for WithTimeout.
func (assertion *AsyncAssertion) Within(timeout time.Duration) types.AsyncAssertion {
	assertion.timeoutInterval = timeout
	return assertion
}

// ProbeEvery is an alias for WithPolling.
func (assertion *AsyncAssertion) ProbeEvery(interval time.Duration) types.AsyncAssertion {
	assertion.pollingInterval = interval
	return assertion
}

// WithContext sets the context that aborts polling when done and that is
// passed to polled functions which accept a context.Context.
func (assertion *AsyncAssertion) WithContext(ctx context.Context) types.AsyncAssertion {
	assertion.ctx = ctx
	return assertion
}

// WithArguments sets extra arguments forwarded to the polled function after
// the optional Gomega and context.Context arguments.
func (assertion *AsyncAssertion) WithArguments(argsToForward ...any) types.AsyncAssertion {
	assertion.argsToForward = argsToForward
	return assertion
}

// MustPassRepeatedly sets how many consecutive successful polls Eventually
// requires; buildActualPoller rejects values other than 1 for Consistently.
func (assertion *AsyncAssertion) MustPassRepeatedly(count int) types.AsyncAssertion {
	assertion.mustPassRepeatedly = count
	return assertion
}

// Should polls until the matcher is satisfied (Eventually) or fails if it is
// ever unsatisfied (Consistently). It returns whether the assertion passed.
func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...any) bool {
	assertion.g.THelper()
	vetOptionalDescription("Asynchronous assertion", optionalDescription...)
	return assertion.match(matcher, true, optionalDescription...)
}

// To is an alias for Should.
func (assertion *AsyncAssertion) To(matcher types.GomegaMatcher, optionalDescription ...any) bool {
	return assertion.Should(matcher, optionalDescription...)
}

// ShouldNot is Should with the desired match outcome inverted.
func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...any) bool {
	assertion.g.THelper()
	vetOptionalDescription("Asynchronous assertion", optionalDescription...)
	return assertion.match(matcher, false, optionalDescription...)
}

// ToNot is an alias for ShouldNot.
func (assertion *AsyncAssertion) ToNot(matcher types.GomegaMatcher, optionalDescription ...any) bool {
	return assertion.ShouldNot(matcher, optionalDescription...)
}

// NotTo is an alias for ShouldNot.
func (assertion *AsyncAssertion) NotTo(matcher types.GomegaMatcher, optionalDescription ...any) bool {
	return assertion.ShouldNot(matcher, optionalDescription...)
}

// buildDescription renders the optional description prefix of a failure
// message: empty when none is given, the result of a lone func() string, or
// otherwise the first element used as a format string for the rest. A
// non-empty result always ends with a newline.
func (assertion *AsyncAssertion) buildDescription(optionalDescription ...any) string {
	switch len(optionalDescription) {
	case 0:
		return ""
	case 1:
		if describe, ok := optionalDescription[0].(func() string); ok {
			return describe() + "\n"
		}
	}
	return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
}

// processReturnValues converts the values returned by a polled function into
// (actual, err):
//   - no return values is an error;
//   - the first value is the actual; a polling signal error there or in any
//     extra value is returned immediately as err;
//   - nil/zero extra values are ignored;
//   - a non-nil error in the last position becomes err (taking precedence);
//   - otherwise the first unexpected non-zero extra value is reported as err.
func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (any, error) {
	if len(values) == 0 {
		return nil, &asyncPolledActualError{
			message: fmt.Sprintf("The function passed to %s did not return any values", assertion.asyncType),
		}
	}

	actual := values[0].Interface()
	if _, ok := AsPollingSignalError(actual); ok {
		return actual, actual.(error)
	}

	var err error
	for i, extraValue := range values[1:] {
		extra := extraValue.Interface()
		if extra == nil {
			continue
		}
		if _, ok := AsPollingSignalError(extra); ok {
			return actual, extra.(error)
		}
		extraType := reflect.TypeOf(extra)
		zero := reflect.Zero(extraType).Interface()
		if reflect.DeepEqual(extra, zero) {
			continue
		}
		// i indexes values[1:], so len(values)-2 is the final return value.
		if i == len(values)-2 && extraType.Implements(errInterface) {
			err = extra.(error)
		}
		if err == nil {
			err = &asyncPolledActualError{
				message: fmt.Sprintf("The function passed to %s had an unexpected non-nil/non-zero return value at index %d:\n%s", assertion.asyncType, i+1, format.Object(extra, 1)),
			}
		}
	}

	return actual, err
}

func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error {
	return fmt.Errorf(`The function passed to %s had an invalid signature of %s. Functions passed to %s must either:

(a) have return values or
(b) take a Gomega interface as their first argument and use that Gomega instance to make assertions.

You can learn more at https://onsi.github.io/gomega/#eventually
`, assertion.asyncType, t, assertion.asyncType)
}

func (assertion *AsyncAssertion) noConfiguredContextForFunctionError() error {
	return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided. Please pass one in using %s().WithContext().

You can learn more at https://onsi.github.io/gomega/#eventually
`, assertion.asyncType, assertion.asyncType)
}

func (assertion *AsyncAssertion) argumentMismatchError(t reflect.Type, numProvided int) error {
	have := "have"
	if numProvided == 1 {
		have = "has"
	}
	return fmt.Errorf(`The function passed to %s has signature %s takes %d arguments but %d %s been provided. Please use %s().WithArguments() to pass the correct set of arguments.

You can learn more at https://onsi.github.io/gomega/#eventually
`, assertion.asyncType, t, t.NumIn(), numProvided, have, assertion.asyncType)
}

func (assertion *AsyncAssertion) invalidMustPassRepeatedlyError(reason string) error {
	return fmt.Errorf(`Invalid use of MustPassRepeatedly with %s %s

You can learn more at https://onsi.github.io/gomega/#eventually
`, assertion.asyncType, reason)
}

// buildActualPoller validates the assertion's configuration and returns a
// closure that yields the current actual value on each call.
//
// For a non-function actual the closure simply returns it. For a function it:
//   - passes a fresh Gomega first if the function's first parameter is a
//     types.Gomega, then the configured context if the function asks for one
//     (unless the first forwarded argument is itself a context), then the
//     WithArguments values;
//   - converts assertion failures made through that Gomega (which panic with
//     asyncGomegaHaltExecutionError) into an asyncPolledActualError;
//   - for functions that take Gomega and return nothing, reports the captured
//     failure (or nil) as the actual value with a nil error.
func (assertion *AsyncAssertion) buildActualPoller() (func() (any, error), error) {
	if !assertion.actualIsFunc {
		return func() (any, error) { return assertion.actual, nil }, nil
	}
	actualValue := reflect.ValueOf(assertion.actual)
	actualType := reflect.TypeOf(assertion.actual)
	numIn, numOut, isVariadic := actualType.NumIn(), actualType.NumOut(), actualType.IsVariadic()

	if numIn == 0 && numOut == 0 {
		return nil, assertion.invalidFunctionError(actualType)
	}
	takesGomega, takesContext := false, false
	if numIn > 0 {
		takesGomega, takesContext = actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType)
	}
	if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) {
		takesContext = true
	}
	// A context supplied explicitly through WithArguments wins over the
	// assertion's own context.
	if takesContext && len(assertion.argsToForward) > 0 && reflect.TypeOf(assertion.argsToForward[0]).Implements(contextType) {
		takesContext = false
	}
	if !takesGomega && numOut == 0 {
		return nil, assertion.invalidFunctionError(actualType)
	}
	if takesContext && assertion.ctx == nil {
		return nil, assertion.noConfiguredContextForFunctionError()
	}

	var assertionFailure error
	inValues := []reflect.Value{}
	if takesGomega {
		inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) {
			skip := 0
			if len(callerSkip) > 0 {
				skip = callerSkip[0]
			}
			_, file, line, _ := runtime.Caller(skip + 1)
			assertionFailure = &asyncPolledActualError{
				message: fmt.Sprintf("The function passed to %s failed at %s:%d with:\n%s", assertion.asyncType, file, line, message),
			}
			// we throw an asyncGomegaHaltExecutionError so that defer GinkgoRecover() can catch this error if the user makes an assertion in a goroutine
			panic(asyncGomegaHaltExecutionError{})
		})))
	}
	if takesContext {
		inValues = append(inValues, reflect.ValueOf(assertion.ctx))
	}
	for _, arg := range assertion.argsToForward {
		inValues = append(inValues, reflect.ValueOf(arg))
	}

	if !isVariadic && numIn != len(inValues) {
		return nil, assertion.argumentMismatchError(actualType, len(inValues))
	} else if isVariadic && len(inValues) < numIn-1 {
		return nil, assertion.argumentMismatchError(actualType, len(inValues))
	}

	if assertion.mustPassRepeatedly != 1 && assertion.asyncType != AsyncAssertionTypeEventually {
		return nil, assertion.invalidMustPassRepeatedlyError("it can only be used with Eventually")
	}
	if assertion.mustPassRepeatedly < 1 {
		return nil, assertion.invalidMustPassRepeatedlyError("parameter can't be < 1")
	}

	return func() (actual any, err error) {
		var values []reflect.Value
		assertionFailure = nil
		// The deferred func runs both after a normal return and after a panic
		// in the polled function (values is then nil).
		defer func() {
			if numOut == 0 && takesGomega {
				actual = assertionFailure
			} else {
				actual, err = assertion.processReturnValues(values)
				_, isAsyncError := AsPollingSignalError(err)
				if assertionFailure != nil && !isAsyncError {
					err = assertionFailure
				}
			}
			if e := recover(); e != nil {
				if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
					err = e.(error)
				} else if assertionFailure == nil {
					// Not our halt panic and not a polling signal: propagate.
					panic(e)
				}
			}
		}()
		values = actualValue.Call(inValues)
		return
	}, nil
}

// afterTimeout returns the channel that fires when the assertion times out.
// An explicit non-negative timeout wins; otherwise the DurationBundle default
// is used. Eventually with a context (and EnforceDefaultTimeoutsWhenUsingContexts
// unset) gets a nil channel, which never fires, so only the context ends it.
func (assertion *AsyncAssertion) afterTimeout() <-chan time.Time {
	if assertion.timeoutInterval >= 0 {
		return time.After(assertion.timeoutInterval)
	}

	if assertion.asyncType == AsyncAssertionTypeConsistently {
		return time.After(assertion.g.DurationBundle.ConsistentlyDuration)
	}
	if assertion.ctx == nil || assertion.g.DurationBundle.EnforceDefaultTimeoutsWhenUsingContexts {
		return time.After(assertion.g.DurationBundle.EventuallyTimeout)
	}
	return nil
}

// afterPolling returns the channel that fires when the next poll is due, using
// the explicit non-negative polling interval or the DurationBundle default.
func (assertion *AsyncAssertion) afterPolling() <-chan time.Time {
	if assertion.pollingInterval >= 0 {
		return time.After(assertion.pollingInterval)
	}
	if assertion.asyncType == AsyncAssertionTypeConsistently {
		return time.After(assertion.g.DurationBundle.ConsistentlyPollingInterval)
	}
	return time.After(assertion.g.DurationBundle.EventuallyPollingInterval)
}

// matcherSaysStopTrying reports whether polling can never change the outcome:
// the actual is a fixed value (not a function) and the matcher says its result
// for that value may not change in the future.
func (assertion *AsyncAssertion) matcherSaysStopTrying(matcher types.GomegaMatcher, value any) bool {
	return !assertion.actualIsFunc && !types.MatchMayChangeInTheFuture(matcher, value)
}

// pollMatcher runs the matcher once, converting a polling-signal panic raised
// by the matcher into the returned error; any other panic is re-raised.
func (assertion *AsyncAssertion) pollMatcher(matcher types.GomegaMatcher, value any) (matches bool, err error) {
	defer func() {
		if e := recover(); e != nil {
			if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
				err = e.(error)
			} else {
				panic(e)
			}
		}
	}()

	matches, err = matcher.Match(value)

	return
}

// match is the polling loop shared by Should and ShouldNot. It polls the
// actual value, runs the matcher, honours polling signals (StopTrying,
// TryAgainAfter), the context, the timeout and MustPassRepeatedly, and reports
// failures through the Gomega fail handler. It returns whether the assertion
// passed.
func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...any) bool {
	timer := time.Now()
	timeout := assertion.afterTimeout()
	// Guards the poll state below, which messageGenerator may read from
	// another goroutine when a progress report is requested.
	var lock sync.Mutex

	var matches, hasLastValidActual bool
	var actual, lastValidActual any
	var actualErr, matcherErr error
	var oracleMatcherSaysStop bool

	assertion.g.THelper()

	pollActual, buildActualPollerErr := assertion.buildActualPoller()
	if buildActualPollerErr != nil {
		assertion.g.Fail(buildActualPollerErr.Error(), 2+assertion.offset)
		return false
	}

	actual, actualErr = pollActual()
	if actualErr == nil {
		lastValidActual = actual
		hasLastValidActual = true
		oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
		matches, matcherErr = assertion.pollMatcher(matcher, actual)
	}

	renderError := func(preamble string, err error) string {
		message := ""
		if pollingSignalErr, ok := AsPollingSignalError(err); ok {
			message = err.Error()
			for _, attachment := range pollingSignalErr.Attachments {
				message += fmt.Sprintf("\n%s:\n", attachment.Description)
				message += format.Object(attachment.Object, 1)
			}
		} else {
			message = preamble + "\n" + format.Object(err, 1)
		}
		return message
	}

	messageGenerator := func() string {
		// can be called out of band by Ginkgo if the user requests a progress report
		lock.Lock()
		defer lock.Unlock()
		message := ""

		if actualErr == nil {
			if matcherErr == nil {
				if desiredMatch != matches {
					if desiredMatch {
						message += matcher.FailureMessage(actual)
					} else {
						message += matcher.NegatedFailureMessage(actual)
					}
				} else {
					if assertion.asyncType == AsyncAssertionTypeConsistently {
						message += "There is no failure as the matcher passed to Consistently has not yet failed"
					} else {
						message += "There is no failure as the matcher passed to Eventually succeeded on its most recent iteration"
					}
				}
			} else {
				var fgErr formattedGomegaError
				if errors.As(matcherErr, &fgErr) {
					message += fgErr.FormattedGomegaError() + "\n"
				} else {
					message += renderError(fmt.Sprintf("The matcher passed to %s returned the following error:", assertion.asyncType), matcherErr)
				}
			}
		} else {
			var fgErr formattedGomegaError
			if errors.As(actualErr, &fgErr) {
				message += fgErr.FormattedGomegaError() + "\n"
			} else {
				message += renderError(fmt.Sprintf("The function passed to %s returned the following error:", assertion.asyncType), actualErr)
			}
			if hasLastValidActual {
				message += fmt.Sprintf("\nAt one point, however, the function did return successfully.\nYet, %s failed because", assertion.asyncType)
				_, e := matcher.Match(lastValidActual)
				if e != nil {
					message += renderError(" the matcher returned the following error:", e)
				} else {
					message += " the matcher was not satisfied:\n"
					if desiredMatch {
						message += matcher.FailureMessage(lastValidActual)
					} else {
						message += matcher.NegatedFailureMessage(lastValidActual)
					}
				}
			}
		}

		description := assertion.buildDescription(optionalDescription...)
		return fmt.Sprintf("%s%s", description, message)
	}

	fail := func(preamble string) {
		assertion.g.THelper()
		assertion.g.Fail(fmt.Sprintf("%s after %.3fs.\n%s", preamble, time.Since(timer).Seconds(), messageGenerator()), 3+assertion.offset)
	}

	var contextDone <-chan struct{}
	if assertion.ctx != nil {
		contextDone = assertion.ctx.Done()
		// The key is presumably the one under which Ginkgo stores its spec
		// context — must stay in sync with Ginkgo.
		if v, ok := assertion.ctx.Value("GINKGO_SPEC_CONTEXT").(contextWithAttachProgressReporter); ok {
			detach := v.AttachProgressReporter(messageGenerator)
			defer detach()
		}
	}

	// Used to count the number of times in a row a step passed
	passedRepeatedlyCount := 0
	for {
		var nextPoll <-chan time.Time
		isTryAgainAfterError := false

		for _, err := range []error{actualErr, matcherErr} {
			if pollingSignalErr, ok := AsPollingSignalError(err); ok {
				if pollingSignalErr.IsStopTrying() {
					if pollingSignalErr.IsSuccessful() {
						if assertion.asyncType == AsyncAssertionTypeEventually {
							fail("Told to stop trying (and ignoring call to Successfully(), as it is only relevant with Consistently)")
						} else {
							return true // early escape hatch for Consistently
						}
					} else {
						fail("Told to stop trying")
					}
					return false
				}
				if pollingSignalErr.IsTryAgainAfter() {
					nextPoll = time.After(pollingSignalErr.TryAgainDuration())
					isTryAgainAfterError = true
				}
			}
		}

		if actualErr == nil && matcherErr == nil && matches == desiredMatch {
			if assertion.asyncType == AsyncAssertionTypeEventually {
				passedRepeatedlyCount++
				if passedRepeatedlyCount == assertion.mustPassRepeatedly {
					return true
				}
			}
		} else if !isTryAgainAfterError {
			if assertion.asyncType == AsyncAssertionTypeConsistently {
				fail("Failed")
				return false
			}
			// Reset the consecutive pass count
			passedRepeatedlyCount = 0
		}

		if oracleMatcherSaysStop {
			if assertion.asyncType == AsyncAssertionTypeEventually {
				fail("No future change is possible. Bailing out early")
				return false
			}
			return true
		}

		if nextPoll == nil {
			nextPoll = assertion.afterPolling()
		}

		select {
		case <-nextPoll:
			a, e := pollActual()
			lock.Lock()
			actual, actualErr = a, e
			if actualErr != nil {
				// The matcher is not run when the actual value could not be
				// polled, so drop the previous poll's matcher outcome. Otherwise a
				// stale matcher TryAgainAfter signal would be re-read at the top of
				// the loop and mask this poll's error (e.g. keeping Consistently
				// from failing, or keeping Eventually's pass streak alive).
				matches, matcherErr = false, nil
			}
			lock.Unlock()
			if actualErr == nil {
				lock.Lock()
				lastValidActual = actual
				hasLastValidActual = true
				lock.Unlock()
				oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
				m, e := assertion.pollMatcher(matcher, actual)
				lock.Lock()
				matches, matcherErr = m, e
				lock.Unlock()
			}
		case <-contextDone:
			err := context.Cause(assertion.ctx)
			if err != nil && err != context.Canceled {
				fail(fmt.Sprintf("Context was cancelled (cause: %s)", err))
			} else {
				fail("Context was cancelled")
			}
			return false
		case <-timeout:
			if assertion.asyncType == AsyncAssertionTypeEventually {
				fail("Timed out")
				return false
			}
			if isTryAgainAfterError {
				fail("Timed out while waiting on TryAgainAfter")
				return false
			}
			return true
		}
	}
}