// Path: blob/main/vendor/github.com/onsi/gomega/matchers/have_field.go
package matchers

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/onsi/gomega/format"
)

// missingFieldError represents a missing field extraction error that
// HaveExistingFieldMatcher can ignore, as opposed to other, severe field
// extraction errors, such as nil pointers, et cetera.
type missingFieldError string

func (e missingFieldError) Error() string {
	return string(e)
}

// extractField walks the dot-separated path in field, starting at actual, and
// returns the value found at its end.
//
// Each path segment is either a struct field name or, when it ends in "()",
// the name of a method that takes no arguments and returns exactly one value;
// that method is called and its result is used. At every step the current
// value must be a struct or a pointer to a struct. matchername is only used
// to prefix error messages.
//
// A segment that names no field or method yields a missingFieldError. Every
// other failure (nil pointer, non-struct value, unexported field, field
// promoted through a nil embedded pointer, method with the wrong signature)
// yields an ordinary error.
func extractField(actual any, field string, matchername string) (any, error) {
	segment, rest, hasRest := strings.Cut(field, ".")

	actualValue := reflect.ValueOf(actual)
	if actualValue.Kind() == reflect.Ptr {
		actualValue = actualValue.Elem()
	}
	// An invalid Value here means actual was a nil pointer or a nil interface.
	if !actualValue.IsValid() {
		return nil, fmt.Errorf("%s encountered nil while dereferencing a pointer of type %T.", matchername, actual)
	}
	if actualValue.Kind() != reflect.Struct {
		return nil, fmt.Errorf("%s encountered:\n%s\nWhich is not a struct.", matchername, format.Object(actual, 1))
	}

	var extractedValue reflect.Value

	if strings.HasSuffix(segment, "()") {
		methodName := strings.TrimSuffix(segment, "()")
		method := actualValue.MethodByName(methodName)
		if !method.IsValid() && actualValue.CanAddr() {
			method = actualValue.Addr().MethodByName(methodName)
		}
		if !method.IsValid() {
			// Pointer-receiver methods are not in the method set of a
			// non-addressable struct value, so retry on an addressable copy.
			ptr := reflect.New(actualValue.Type())
			ptr.Elem().Set(actualValue)
			method = ptr.MethodByName(methodName)
		}
		if !method.IsValid() {
			return nil, missingFieldError(fmt.Sprintf("%s could not find method named '%s' in struct of type %T.", matchername, segment, actual))
		}
		methodType := method.Type()
		if methodType.NumIn() != 0 || methodType.NumOut() != 1 {
			return nil, fmt.Errorf("%s found an invalid method named '%s' in struct of type %T.\nMethods must take no arguments and return exactly one value.", matchername, segment, actual)
		}
		extractedValue = method.Call(nil)[0]
	} else {
		structField, found := actualValue.Type().FieldByName(segment)
		if !found {
			return nil, missingFieldError(fmt.Sprintf("%s could not find field named '%s' in struct:\n%s", matchername, segment, format.Object(actual, 1)))
		}
		// Value.FieldByName would panic on a field promoted through a nil
		// embedded pointer; FieldByIndexErr reports that case as an error.
		fieldValue, err := actualValue.FieldByIndexErr(structField.Index)
		if err != nil {
			return nil, fmt.Errorf("%s could not access field named '%s' in struct of type %T: %w", matchername, segment, actual, err)
		}
		// Interface() panics on values read from unexported fields, so reject
		// them here with an error instead of crashing the caller.
		if !fieldValue.CanInterface() {
			return nil, fmt.Errorf("%s cannot read unexported field named '%s' in struct of type %T.", matchername, segment, actual)
		}
		extractedValue = fieldValue
	}

	if !hasRest {
		return extractedValue.Interface(), nil
	}
	return extractField(extractedValue.Interface(), rest, matchername)
}

// HaveFieldMatcher succeeds when the value located by Field satisfies
// Expected. Field is a dot-separated path of struct field names and
// zero-argument "Method()" calls (see extractField). Expected may itself be a
// matcher; any other value is compared using EqualMatcher.
type HaveFieldMatcher struct {
	Field    string
	Expected any
}

// expectedMatcher returns Expected when it already is a matcher and otherwise
// wraps it in an EqualMatcher.
func (matcher *HaveFieldMatcher) expectedMatcher() omegaMatcher {
	if m, isMatcher := matcher.Expected.(omegaMatcher); isMatcher {
		return m
	}
	return &EqualMatcher{Expected: matcher.Expected}
}

// Match extracts Field from actual and applies the expected matcher to it.
// Extraction failures are returned as errors rather than as failed matches.
func (matcher *HaveFieldMatcher) Match(actual any) (success bool, err error) {
	extractedField, err := extractField(actual, matcher.Field, "HaveField")
	if err != nil {
		return false, err
	}
	return matcher.expectedMatcher().Match(extractedField)
}

// FailureMessage explains why the extracted field did not satisfy the
// expected matcher, delegating the detail to that matcher.
func (matcher *HaveFieldMatcher) FailureMessage(actual any) (message string) {
	extractedField, err := extractField(actual, matcher.Field, "HaveField")
	if err != nil {
		// Match already returns this error to its caller, so reaching this
		// branch is unexpected; report it instead of panicking.
		return fmt.Sprintf("Failed to extract field '%s': %s", matcher.Field, err)
	}
	message = fmt.Sprintf("Value for field '%s' failed to satisfy matcher.\n", matcher.Field)
	message += matcher.expectedMatcher().FailureMessage(extractedField)
	return message
}

// NegatedFailureMessage explains why the extracted field unexpectedly
// satisfied the expected matcher, delegating the detail to that matcher.
func (matcher *HaveFieldMatcher) NegatedFailureMessage(actual any) (message string) {
	extractedField, err := extractField(actual, matcher.Field, "HaveField")
	if err != nil {
		// Match already returns this error to its caller, so reaching this
		// branch is unexpected; report it instead of panicking.
		return fmt.Sprintf("Failed to extract field '%s': %s", matcher.Field, err)
	}
	message = fmt.Sprintf("Value for field '%s' satisfied matcher, but should not have.\n", matcher.Field)
	message += matcher.expectedMatcher().NegatedFailureMessage(extractedField)
	return message
}