// Path: blob/main/vendor/github.com/spf13/cobra/flag_groups.go
// 2875 views
// Copyright 2013-2023 The Cobra Authors1//2// Licensed under the Apache License, Version 2.0 (the "License");3// you may not use this file except in compliance with the License.4// You may obtain a copy of the License at5//6// http://www.apache.org/licenses/LICENSE-2.07//8// Unless required by applicable law or agreed to in writing, software9// distributed under the License is distributed on an "AS IS" BASIS,10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.11// See the License for the specific language governing permissions and12// limitations under the License.1314package cobra1516import (17"fmt"18"sort"19"strings"2021flag "github.com/spf13/pflag"22)2324const (25requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"26oneRequiredAnnotation = "cobra_annotation_one_required"27mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"28)2930// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors31// if the command is invoked with a subset (but not all) of the given flags.32func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {33c.mergePersistentFlags()34for _, v := range flagNames {35f := c.Flags().Lookup(v)36if f == nil {37panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))38}39if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {40// Only errs if the flag isn't found.41panic(err)42}43}44}4546// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors47// if the command is invoked without at least one flag from the given set of flags.48func (c *Command) MarkFlagsOneRequired(flagNames ...string) {49c.mergePersistentFlags()50for _, v := range flagNames {51f := c.Flags().Lookup(v)52if f == nil {53panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))54}55if err := 
c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {56// Only errs if the flag isn't found.57panic(err)58}59}60}6162// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors63// if the command is invoked with more than one flag from the given set of flags.64func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {65c.mergePersistentFlags()66for _, v := range flagNames {67f := c.Flags().Lookup(v)68if f == nil {69panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))70}71// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.72if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {73panic(err)74}75}76}7778// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the79// first error encountered.80func (c *Command) ValidateFlagGroups() error {81if c.DisableFlagParsing {82return nil83}8485flags := c.Flags()8687// groupStatus format is the list of flags as a unique ID,88// then a map of each flag name and whether it is set or not.89groupStatus := map[string]map[string]bool{}90oneRequiredGroupStatus := map[string]map[string]bool{}91mutuallyExclusiveGroupStatus := map[string]map[string]bool{}92flags.VisitAll(func(pflag *flag.Flag) {93processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)94processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)95processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)96})9798if err := validateRequiredFlagGroups(groupStatus); err != nil {99return err100}101if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {102return err103}104if err 
:= validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {105return err106}107return nil108}109110func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {111for _, fname := range flagnames {112f := fs.Lookup(fname)113if f == nil {114return false115}116}117return true118}119120func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {121groupInfo, found := pflag.Annotations[annotation]122if found {123for _, group := range groupInfo {124if groupStatus[group] == nil {125flagnames := strings.Split(group, " ")126127// Only consider this flag group at all if all the flags are defined.128if !hasAllFlags(flags, flagnames...) {129continue130}131132groupStatus[group] = make(map[string]bool, len(flagnames))133for _, name := range flagnames {134groupStatus[group][name] = false135}136}137138groupStatus[group][pflag.Name] = pflag.Changed139}140}141}142143func validateRequiredFlagGroups(data map[string]map[string]bool) error {144keys := sortedKeys(data)145for _, flagList := range keys {146flagnameAndStatus := data[flagList]147148unset := []string{}149for flagname, isSet := range flagnameAndStatus {150if !isSet {151unset = append(unset, flagname)152}153}154if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {155continue156}157158// Sort values, so they can be tested/scripted against consistently.159sort.Strings(unset)160return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)161}162163return nil164}165166func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {167keys := sortedKeys(data)168for _, flagList := range keys {169flagnameAndStatus := data[flagList]170var set []string171for flagname, isSet := range flagnameAndStatus {172if isSet {173set = append(set, flagname)174}175}176if len(set) >= 1 {177continue178}179180// Sort values, so they can be tested/scripted against 
consistently.181sort.Strings(set)182return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)183}184return nil185}186187func validateExclusiveFlagGroups(data map[string]map[string]bool) error {188keys := sortedKeys(data)189for _, flagList := range keys {190flagnameAndStatus := data[flagList]191var set []string192for flagname, isSet := range flagnameAndStatus {193if isSet {194set = append(set, flagname)195}196}197if len(set) == 0 || len(set) == 1 {198continue199}200201// Sort values, so they can be tested/scripted against consistently.202sort.Strings(set)203return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)204}205return nil206}207208func sortedKeys(m map[string]map[string]bool) []string {209keys := make([]string, len(m))210i := 0211for k := range m {212keys[i] = k213i++214}215sort.Strings(keys)216return keys217}218219// enforceFlagGroupsForCompletion will do the following:220// - when a flag in a group is present, other flags in the group will be marked required221// - when none of the flags in a one-required group are present, all flags in the group will be marked required222// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden223// This allows the standard completion logic to behave appropriately for flag groups224func (c *Command) enforceFlagGroupsForCompletion() {225if c.DisableFlagParsing {226return227}228229flags := c.Flags()230groupStatus := map[string]map[string]bool{}231oneRequiredGroupStatus := map[string]map[string]bool{}232mutuallyExclusiveGroupStatus := map[string]map[string]bool{}233c.Flags().VisitAll(func(pflag *flag.Flag) {234processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)235processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)236processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, 
mutuallyExclusiveGroupStatus)237})238239// If a flag that is part of a group is present, we make all the other flags240// of that group required so that the shell completion suggests them automatically241for flagList, flagnameAndStatus := range groupStatus {242for _, isSet := range flagnameAndStatus {243if isSet {244// One of the flags of the group is set, mark the other ones as required245for _, fName := range strings.Split(flagList, " ") {246_ = c.MarkFlagRequired(fName)247}248}249}250}251252// If none of the flags of a one-required group are present, we make all the flags253// of that group required so that the shell completion suggests them automatically254for flagList, flagnameAndStatus := range oneRequiredGroupStatus {255isSet := false256257for _, isSet = range flagnameAndStatus {258if isSet {259break260}261}262263// None of the flags of the group are set, mark all flags in the group264// as required265if !isSet {266for _, fName := range strings.Split(flagList, " ") {267_ = c.MarkFlagRequired(fName)268}269}270}271272// If a flag that is mutually exclusive to others is present, we hide the other273// flags of that group so the shell completion does not suggest them274for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {275for flagName, isSet := range flagnameAndStatus {276if isSet {277// One of the flags of the mutually exclusive group is set, mark the other ones as hidden278// Don't mark the flag that is already set as hidden because it may be an279// array or slice flag and therefore must continue being suggested280for _, fName := range strings.Split(flagList, " ") {281if fName != flagName {282flag := c.Flags().Lookup(fName)283flag.Hidden = true284}285}286}287}288}289}290291292