Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/vendor/github.com/onsi/gomega/internal/async_assertion.go
2880 views
1
package internal
2
3
import (
4
"context"
5
"errors"
6
"fmt"
7
"reflect"
8
"runtime"
9
"sync"
10
"time"
11
12
"github.com/onsi/gomega/format"
13
"github.com/onsi/gomega/types"
14
)
15
16
var errInterface = reflect.TypeOf((*error)(nil)).Elem()
17
var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem()
18
var contextType = reflect.TypeOf(new(context.Context)).Elem()
19
20
// formattedGomegaError is implemented by errors whose message is already
// formatted for display. When building a failure message, such errors are
// printed via FormattedGomegaError rather than under a generic
// "returned the following error" preamble.
type formattedGomegaError interface {
	FormattedGomegaError() string
}

// asyncPolledActualError describes a problem with what the polled function
// produced: unexpected return values, or a failed assertion made through the
// Gomega instance passed to it.
type asyncPolledActualError struct {
	message string
}

// Compile-time checks that asyncPolledActualError satisfies both interfaces.
var (
	_ error                = (*asyncPolledActualError)(nil)
	_ formattedGomegaError = (*asyncPolledActualError)(nil)
)

// Error returns the stored message.
func (err *asyncPolledActualError) Error() string {
	return err.message
}

// FormattedGomegaError returns the same text as Error; the message is built
// ready for display, so no further formatting is applied.
func (err *asyncPolledActualError) FormattedGomegaError() string {
	return err.Error()
}
35
36
// contextWithAttachProgressReporter is the optional capability looked up on the
// value stored under the "GINKGO_SPEC_CONTEXT" context key. The assertion hands
// it a message-generating callback and defers the returned detach function.
type contextWithAttachProgressReporter interface {
	AttachProgressReporter(reporter func() string) (detach func())
}
39
40
// asyncGomegaHaltExecutionError is the panic value used to abort a polled
// function once an assertion made through its injected Gomega has failed.
// Panicking (rather than returning) means `defer GinkgoRecover()` can catch it
// when the assertion ran inside a user-spawned goroutine.
type asyncGomegaHaltExecutionError struct{}

// GinkgoRecoverShouldIgnoreThisPanic is a marker method with no behavior; per
// its name it signals to GinkgoRecover that this panic should be ignored (the
// check itself lives in Ginkgo, not in this file).
func (asyncGomegaHaltExecutionError) GinkgoRecoverShouldIgnoreThisPanic() {}

// Error explains how to capture the panic when it escapes from a goroutine
// that did not defer GinkgoRecover.
func (asyncGomegaHaltExecutionError) Error() string {
	const msg = "An assertion has failed in a goroutine. You should call\n" +
		"\n" +
		"defer GinkgoRecover()\n" +
		"\n" +
		"at the top of the goroutine that caused this panic. This will allow Ginkgo and Gomega to correctly capture and manage this panic."
	return msg
}
50
51
// AsyncAssertionType identifies which asynchronous assertion an AsyncAssertion
// implements.
type AsyncAssertionType uint

const (
	// AsyncAssertionTypeEventually polls until the matcher passes, failing if
	// the timeout elapses first.
	AsyncAssertionTypeEventually AsyncAssertionType = iota
	// AsyncAssertionTypeConsistently polls until the timeout elapses, failing
	// as soon as the matcher fails.
	AsyncAssertionTypeConsistently
)

// String returns "Eventually" or "Consistently", or a sentinel for any value
// outside the declared constants.
func (at AsyncAssertionType) String() string {
	names := [...]string{
		AsyncAssertionTypeEventually:   "Eventually",
		AsyncAssertionTypeConsistently: "Consistently",
	}
	if uint(at) < uint(len(names)) {
		return names[at]
	}
	return "INVALID ASYNC ASSERTION TYPE"
}
67
68
type AsyncAssertion struct {
69
asyncType AsyncAssertionType
70
71
actualIsFunc bool
72
actual any
73
argsToForward []any
74
75
timeoutInterval time.Duration
76
pollingInterval time.Duration
77
mustPassRepeatedly int
78
ctx context.Context
79
offset int
80
g *Gomega
81
}
82
83
func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput any, g *Gomega, timeoutInterval time.Duration, pollingInterval time.Duration, mustPassRepeatedly int, ctx context.Context, offset int) *AsyncAssertion {
84
out := &AsyncAssertion{
85
asyncType: asyncType,
86
timeoutInterval: timeoutInterval,
87
pollingInterval: pollingInterval,
88
mustPassRepeatedly: mustPassRepeatedly,
89
offset: offset,
90
ctx: ctx,
91
g: g,
92
}
93
94
out.actual = actualInput
95
if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func {
96
out.actualIsFunc = true
97
}
98
99
return out
100
}
101
102
func (assertion *AsyncAssertion) WithOffset(offset int) types.AsyncAssertion {
103
assertion.offset = offset
104
return assertion
105
}
106
107
func (assertion *AsyncAssertion) WithTimeout(interval time.Duration) types.AsyncAssertion {
108
assertion.timeoutInterval = interval
109
return assertion
110
}
111
112
func (assertion *AsyncAssertion) WithPolling(interval time.Duration) types.AsyncAssertion {
113
assertion.pollingInterval = interval
114
return assertion
115
}
116
117
func (assertion *AsyncAssertion) Within(timeout time.Duration) types.AsyncAssertion {
118
assertion.timeoutInterval = timeout
119
return assertion
120
}
121
122
func (assertion *AsyncAssertion) ProbeEvery(interval time.Duration) types.AsyncAssertion {
123
assertion.pollingInterval = interval
124
return assertion
125
}
126
127
func (assertion *AsyncAssertion) WithContext(ctx context.Context) types.AsyncAssertion {
128
assertion.ctx = ctx
129
return assertion
130
}
131
132
func (assertion *AsyncAssertion) WithArguments(argsToForward ...any) types.AsyncAssertion {
133
assertion.argsToForward = argsToForward
134
return assertion
135
}
136
137
func (assertion *AsyncAssertion) MustPassRepeatedly(count int) types.AsyncAssertion {
138
assertion.mustPassRepeatedly = count
139
return assertion
140
}
141
142
func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...any) bool {
143
assertion.g.THelper()
144
vetOptionalDescription("Asynchronous assertion", optionalDescription...)
145
return assertion.match(matcher, true, optionalDescription...)
146
}
147
148
func (assertion *AsyncAssertion) To(matcher types.GomegaMatcher, optionalDescription ...any) bool {
149
return assertion.Should(matcher, optionalDescription...)
150
}
151
152
func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...any) bool {
153
assertion.g.THelper()
154
vetOptionalDescription("Asynchronous assertion", optionalDescription...)
155
return assertion.match(matcher, false, optionalDescription...)
156
}
157
158
func (assertion *AsyncAssertion) ToNot(matcher types.GomegaMatcher, optionalDescription ...any) bool {
159
return assertion.ShouldNot(matcher, optionalDescription...)
160
}
161
162
func (assertion *AsyncAssertion) NotTo(matcher types.GomegaMatcher, optionalDescription ...any) bool {
163
return assertion.ShouldNot(matcher, optionalDescription...)
164
}
165
166
func (assertion *AsyncAssertion) buildDescription(optionalDescription ...any) string {
167
switch len(optionalDescription) {
168
case 0:
169
return ""
170
case 1:
171
if describe, ok := optionalDescription[0].(func() string); ok {
172
return describe() + "\n"
173
}
174
}
175
return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n"
176
}
177
178
func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (any, error) {
179
if len(values) == 0 {
180
return nil, &asyncPolledActualError{
181
message: fmt.Sprintf("The function passed to %s did not return any values", assertion.asyncType),
182
}
183
}
184
185
actual := values[0].Interface()
186
if _, ok := AsPollingSignalError(actual); ok {
187
return actual, actual.(error)
188
}
189
190
var err error
191
for i, extraValue := range values[1:] {
192
extra := extraValue.Interface()
193
if extra == nil {
194
continue
195
}
196
if _, ok := AsPollingSignalError(extra); ok {
197
return actual, extra.(error)
198
}
199
extraType := reflect.TypeOf(extra)
200
zero := reflect.Zero(extraType).Interface()
201
if reflect.DeepEqual(extra, zero) {
202
continue
203
}
204
if i == len(values)-2 && extraType.Implements(errInterface) {
205
err = extra.(error)
206
}
207
if err == nil {
208
err = &asyncPolledActualError{
209
message: fmt.Sprintf("The function passed to %s had an unexpected non-nil/non-zero return value at index %d:\n%s", assertion.asyncType, i+1, format.Object(extra, 1)),
210
}
211
}
212
}
213
214
return actual, err
215
}
216
217
func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error {
218
return fmt.Errorf(`The function passed to %s had an invalid signature of %s. Functions passed to %s must either:
219
220
(a) have return values or
221
(b) take a Gomega interface as their first argument and use that Gomega instance to make assertions.
222
223
You can learn more at https://onsi.github.io/gomega/#eventually
224
`, assertion.asyncType, t, assertion.asyncType)
225
}
226
227
func (assertion *AsyncAssertion) noConfiguredContextForFunctionError() error {
228
return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided. Please pass one in using %s().WithContext().
229
230
You can learn more at https://onsi.github.io/gomega/#eventually
231
`, assertion.asyncType, assertion.asyncType)
232
}
233
234
func (assertion *AsyncAssertion) argumentMismatchError(t reflect.Type, numProvided int) error {
235
have := "have"
236
if numProvided == 1 {
237
have = "has"
238
}
239
return fmt.Errorf(`The function passed to %s has signature %s takes %d arguments but %d %s been provided. Please use %s().WithArguments() to pass the correct set of arguments.
240
241
You can learn more at https://onsi.github.io/gomega/#eventually
242
`, assertion.asyncType, t, t.NumIn(), numProvided, have, assertion.asyncType)
243
}
244
245
func (assertion *AsyncAssertion) invalidMustPassRepeatedlyError(reason string) error {
246
return fmt.Errorf(`Invalid use of MustPassRepeatedly with %s %s
247
248
You can learn more at https://onsi.github.io/gomega/#eventually
249
`, assertion.asyncType, reason)
250
}
251
252
func (assertion *AsyncAssertion) buildActualPoller() (func() (any, error), error) {
253
if !assertion.actualIsFunc {
254
return func() (any, error) { return assertion.actual, nil }, nil
255
}
256
actualValue := reflect.ValueOf(assertion.actual)
257
actualType := reflect.TypeOf(assertion.actual)
258
numIn, numOut, isVariadic := actualType.NumIn(), actualType.NumOut(), actualType.IsVariadic()
259
260
if numIn == 0 && numOut == 0 {
261
return nil, assertion.invalidFunctionError(actualType)
262
}
263
takesGomega, takesContext := false, false
264
if numIn > 0 {
265
takesGomega, takesContext = actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType)
266
}
267
if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) {
268
takesContext = true
269
}
270
if takesContext && len(assertion.argsToForward) > 0 && reflect.TypeOf(assertion.argsToForward[0]).Implements(contextType) {
271
takesContext = false
272
}
273
if !takesGomega && numOut == 0 {
274
return nil, assertion.invalidFunctionError(actualType)
275
}
276
if takesContext && assertion.ctx == nil {
277
return nil, assertion.noConfiguredContextForFunctionError()
278
}
279
280
var assertionFailure error
281
inValues := []reflect.Value{}
282
if takesGomega {
283
inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) {
284
skip := 0
285
if len(callerSkip) > 0 {
286
skip = callerSkip[0]
287
}
288
_, file, line, _ := runtime.Caller(skip + 1)
289
assertionFailure = &asyncPolledActualError{
290
message: fmt.Sprintf("The function passed to %s failed at %s:%d with:\n%s", assertion.asyncType, file, line, message),
291
}
292
// we throw an asyncGomegaHaltExecutionError so that defer GinkgoRecover() can catch this error if the user makes an assertion in a goroutine
293
panic(asyncGomegaHaltExecutionError{})
294
})))
295
}
296
if takesContext {
297
inValues = append(inValues, reflect.ValueOf(assertion.ctx))
298
}
299
for _, arg := range assertion.argsToForward {
300
inValues = append(inValues, reflect.ValueOf(arg))
301
}
302
303
if !isVariadic && numIn != len(inValues) {
304
return nil, assertion.argumentMismatchError(actualType, len(inValues))
305
} else if isVariadic && len(inValues) < numIn-1 {
306
return nil, assertion.argumentMismatchError(actualType, len(inValues))
307
}
308
309
if assertion.mustPassRepeatedly != 1 && assertion.asyncType != AsyncAssertionTypeEventually {
310
return nil, assertion.invalidMustPassRepeatedlyError("it can only be used with Eventually")
311
}
312
if assertion.mustPassRepeatedly < 1 {
313
return nil, assertion.invalidMustPassRepeatedlyError("parameter can't be < 1")
314
}
315
316
return func() (actual any, err error) {
317
var values []reflect.Value
318
assertionFailure = nil
319
defer func() {
320
if numOut == 0 && takesGomega {
321
actual = assertionFailure
322
} else {
323
actual, err = assertion.processReturnValues(values)
324
_, isAsyncError := AsPollingSignalError(err)
325
if assertionFailure != nil && !isAsyncError {
326
err = assertionFailure
327
}
328
}
329
if e := recover(); e != nil {
330
if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
331
err = e.(error)
332
} else if assertionFailure == nil {
333
panic(e)
334
}
335
}
336
}()
337
values = actualValue.Call(inValues)
338
return
339
}, nil
340
}
341
342
func (assertion *AsyncAssertion) afterTimeout() <-chan time.Time {
343
if assertion.timeoutInterval >= 0 {
344
return time.After(assertion.timeoutInterval)
345
}
346
347
if assertion.asyncType == AsyncAssertionTypeConsistently {
348
return time.After(assertion.g.DurationBundle.ConsistentlyDuration)
349
} else {
350
if assertion.ctx == nil || assertion.g.DurationBundle.EnforceDefaultTimeoutsWhenUsingContexts {
351
return time.After(assertion.g.DurationBundle.EventuallyTimeout)
352
} else {
353
return nil
354
}
355
}
356
}
357
358
func (assertion *AsyncAssertion) afterPolling() <-chan time.Time {
359
if assertion.pollingInterval >= 0 {
360
return time.After(assertion.pollingInterval)
361
}
362
if assertion.asyncType == AsyncAssertionTypeConsistently {
363
return time.After(assertion.g.DurationBundle.ConsistentlyPollingInterval)
364
} else {
365
return time.After(assertion.g.DurationBundle.EventuallyPollingInterval)
366
}
367
}
368
369
func (assertion *AsyncAssertion) matcherSaysStopTrying(matcher types.GomegaMatcher, value any) bool {
370
if assertion.actualIsFunc || types.MatchMayChangeInTheFuture(matcher, value) {
371
return false
372
}
373
return true
374
}
375
376
func (assertion *AsyncAssertion) pollMatcher(matcher types.GomegaMatcher, value any) (matches bool, err error) {
377
defer func() {
378
if e := recover(); e != nil {
379
if _, isAsyncError := AsPollingSignalError(e); isAsyncError {
380
err = e.(error)
381
} else {
382
panic(e)
383
}
384
}
385
}()
386
387
matches, err = matcher.Match(value)
388
389
return
390
}
391
392
func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...any) bool {
393
timer := time.Now()
394
timeout := assertion.afterTimeout()
395
lock := sync.Mutex{}
396
397
var matches, hasLastValidActual bool
398
var actual, lastValidActual any
399
var actualErr, matcherErr error
400
var oracleMatcherSaysStop bool
401
402
assertion.g.THelper()
403
404
pollActual, buildActualPollerErr := assertion.buildActualPoller()
405
if buildActualPollerErr != nil {
406
assertion.g.Fail(buildActualPollerErr.Error(), 2+assertion.offset)
407
return false
408
}
409
410
actual, actualErr = pollActual()
411
if actualErr == nil {
412
lastValidActual = actual
413
hasLastValidActual = true
414
oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
415
matches, matcherErr = assertion.pollMatcher(matcher, actual)
416
}
417
418
renderError := func(preamble string, err error) string {
419
message := ""
420
if pollingSignalErr, ok := AsPollingSignalError(err); ok {
421
message = err.Error()
422
for _, attachment := range pollingSignalErr.Attachments {
423
message += fmt.Sprintf("\n%s:\n", attachment.Description)
424
message += format.Object(attachment.Object, 1)
425
}
426
} else {
427
message = preamble + "\n" + format.Object(err, 1)
428
}
429
return message
430
}
431
432
messageGenerator := func() string {
433
// can be called out of band by Ginkgo if the user requests a progress report
434
lock.Lock()
435
defer lock.Unlock()
436
message := ""
437
438
if actualErr == nil {
439
if matcherErr == nil {
440
if desiredMatch != matches {
441
if desiredMatch {
442
message += matcher.FailureMessage(actual)
443
} else {
444
message += matcher.NegatedFailureMessage(actual)
445
}
446
} else {
447
if assertion.asyncType == AsyncAssertionTypeConsistently {
448
message += "There is no failure as the matcher passed to Consistently has not yet failed"
449
} else {
450
message += "There is no failure as the matcher passed to Eventually succeeded on its most recent iteration"
451
}
452
}
453
} else {
454
var fgErr formattedGomegaError
455
if errors.As(matcherErr, &fgErr) {
456
message += fgErr.FormattedGomegaError() + "\n"
457
} else {
458
message += renderError(fmt.Sprintf("The matcher passed to %s returned the following error:", assertion.asyncType), matcherErr)
459
}
460
}
461
} else {
462
var fgErr formattedGomegaError
463
if errors.As(actualErr, &fgErr) {
464
message += fgErr.FormattedGomegaError() + "\n"
465
} else {
466
message += renderError(fmt.Sprintf("The function passed to %s returned the following error:", assertion.asyncType), actualErr)
467
}
468
if hasLastValidActual {
469
message += fmt.Sprintf("\nAt one point, however, the function did return successfully.\nYet, %s failed because", assertion.asyncType)
470
_, e := matcher.Match(lastValidActual)
471
if e != nil {
472
message += renderError(" the matcher returned the following error:", e)
473
} else {
474
message += " the matcher was not satisfied:\n"
475
if desiredMatch {
476
message += matcher.FailureMessage(lastValidActual)
477
} else {
478
message += matcher.NegatedFailureMessage(lastValidActual)
479
}
480
}
481
}
482
}
483
484
description := assertion.buildDescription(optionalDescription...)
485
return fmt.Sprintf("%s%s", description, message)
486
}
487
488
fail := func(preamble string) {
489
assertion.g.THelper()
490
assertion.g.Fail(fmt.Sprintf("%s after %.3fs.\n%s", preamble, time.Since(timer).Seconds(), messageGenerator()), 3+assertion.offset)
491
}
492
493
var contextDone <-chan struct{}
494
if assertion.ctx != nil {
495
contextDone = assertion.ctx.Done()
496
if v, ok := assertion.ctx.Value("GINKGO_SPEC_CONTEXT").(contextWithAttachProgressReporter); ok {
497
detach := v.AttachProgressReporter(messageGenerator)
498
defer detach()
499
}
500
}
501
502
// Used to count the number of times in a row a step passed
503
passedRepeatedlyCount := 0
504
for {
505
var nextPoll <-chan time.Time = nil
506
var isTryAgainAfterError = false
507
508
for _, err := range []error{actualErr, matcherErr} {
509
if pollingSignalErr, ok := AsPollingSignalError(err); ok {
510
if pollingSignalErr.IsStopTrying() {
511
if pollingSignalErr.IsSuccessful() {
512
if assertion.asyncType == AsyncAssertionTypeEventually {
513
fail("Told to stop trying (and ignoring call to Successfully(), as it is only relevant with Consistently)")
514
} else {
515
return true // early escape hatch for Consistently
516
}
517
} else {
518
fail("Told to stop trying")
519
}
520
return false
521
}
522
if pollingSignalErr.IsTryAgainAfter() {
523
nextPoll = time.After(pollingSignalErr.TryAgainDuration())
524
isTryAgainAfterError = true
525
}
526
}
527
}
528
529
if actualErr == nil && matcherErr == nil && matches == desiredMatch {
530
if assertion.asyncType == AsyncAssertionTypeEventually {
531
passedRepeatedlyCount += 1
532
if passedRepeatedlyCount == assertion.mustPassRepeatedly {
533
return true
534
}
535
}
536
} else if !isTryAgainAfterError {
537
if assertion.asyncType == AsyncAssertionTypeConsistently {
538
fail("Failed")
539
return false
540
}
541
// Reset the consecutive pass count
542
passedRepeatedlyCount = 0
543
}
544
545
if oracleMatcherSaysStop {
546
if assertion.asyncType == AsyncAssertionTypeEventually {
547
fail("No future change is possible. Bailing out early")
548
return false
549
} else {
550
return true
551
}
552
}
553
554
if nextPoll == nil {
555
nextPoll = assertion.afterPolling()
556
}
557
558
select {
559
case <-nextPoll:
560
a, e := pollActual()
561
lock.Lock()
562
actual, actualErr = a, e
563
lock.Unlock()
564
if actualErr == nil {
565
lock.Lock()
566
lastValidActual = actual
567
hasLastValidActual = true
568
lock.Unlock()
569
oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual)
570
m, e := assertion.pollMatcher(matcher, actual)
571
lock.Lock()
572
matches, matcherErr = m, e
573
lock.Unlock()
574
}
575
case <-contextDone:
576
err := context.Cause(assertion.ctx)
577
if err != nil && err != context.Canceled {
578
fail(fmt.Sprintf("Context was cancelled (cause: %s)", err))
579
} else {
580
fail("Context was cancelled")
581
}
582
return false
583
case <-timeout:
584
if assertion.asyncType == AsyncAssertionTypeEventually {
585
fail("Timed out")
586
return false
587
} else {
588
if isTryAgainAfterError {
589
fail("Timed out while waiting on TryAgainAfter")
590
return false
591
}
592
return true
593
}
594
}
595
}
596
}
597
598