Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
kardolus
GitHub Repository: kardolus/chatgpt-cli
Path: blob/main/vendor/github.com/spf13/cobra/flag_groups.go
2875 views
1
// Copyright 2013-2023 The Cobra Authors
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
// http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14
15
package cobra
16
17
import (
18
"fmt"
19
"sort"
20
"strings"
21
22
flag "github.com/spf13/pflag"
23
)
24
25
// Annotation keys under which flag-group membership is recorded on each
// member flag. Every value stored under one of these keys is a list with one
// entry per group the flag belongs to; each entry is the group's flag names
// joined by a single space.

// requiredAsGroupAnnotation marks groups in which, if any flag is set, all
// flags must be set.
const requiredAsGroupAnnotation = "cobra_annotation_required_if_others_set"

// oneRequiredAnnotation marks groups in which at least one flag must be set.
const oneRequiredAnnotation = "cobra_annotation_one_required"

// mutuallyExclusiveAnnotation marks groups in which at most one flag may be set.
const mutuallyExclusiveAnnotation = "cobra_annotation_mutually_exclusive"
30
31
// MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
32
// if the command is invoked with a subset (but not all) of the given flags.
33
func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
34
c.mergePersistentFlags()
35
for _, v := range flagNames {
36
f := c.Flags().Lookup(v)
37
if f == nil {
38
panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
39
}
40
if err := c.Flags().SetAnnotation(v, requiredAsGroupAnnotation, append(f.Annotations[requiredAsGroupAnnotation], strings.Join(flagNames, " "))); err != nil {
41
// Only errs if the flag isn't found.
42
panic(err)
43
}
44
}
45
}
46
47
// MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
48
// if the command is invoked without at least one flag from the given set of flags.
49
func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
50
c.mergePersistentFlags()
51
for _, v := range flagNames {
52
f := c.Flags().Lookup(v)
53
if f == nil {
54
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
55
}
56
if err := c.Flags().SetAnnotation(v, oneRequiredAnnotation, append(f.Annotations[oneRequiredAnnotation], strings.Join(flagNames, " "))); err != nil {
57
// Only errs if the flag isn't found.
58
panic(err)
59
}
60
}
61
}
62
63
// MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
64
// if the command is invoked with more than one flag from the given set of flags.
65
func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
66
c.mergePersistentFlags()
67
for _, v := range flagNames {
68
f := c.Flags().Lookup(v)
69
if f == nil {
70
panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
71
}
72
// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
73
if err := c.Flags().SetAnnotation(v, mutuallyExclusiveAnnotation, append(f.Annotations[mutuallyExclusiveAnnotation], strings.Join(flagNames, " "))); err != nil {
74
panic(err)
75
}
76
}
77
}
78
79
// ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
80
// first error encountered.
81
func (c *Command) ValidateFlagGroups() error {
82
if c.DisableFlagParsing {
83
return nil
84
}
85
86
flags := c.Flags()
87
88
// groupStatus format is the list of flags as a unique ID,
89
// then a map of each flag name and whether it is set or not.
90
groupStatus := map[string]map[string]bool{}
91
oneRequiredGroupStatus := map[string]map[string]bool{}
92
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
93
flags.VisitAll(func(pflag *flag.Flag) {
94
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
95
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
96
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
97
})
98
99
if err := validateRequiredFlagGroups(groupStatus); err != nil {
100
return err
101
}
102
if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
103
return err
104
}
105
if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
106
return err
107
}
108
return nil
109
}
110
111
// hasAllFlags reports whether every name in flagnames is defined in fs.
// It returns true for an empty flagnames list.
func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
	for i := range flagnames {
		if fs.Lookup(flagnames[i]) == nil {
			return false
		}
	}
	return true
}
120
121
// processFlagForGroupAnnotation records pflag's Changed state in groupStatus
// for every group listed under the given annotation key on pflag.
//
// groupStatus is keyed by the group's space-joined flag list. The first time a
// group is seen, an entry is created with every member set to false, but only
// if all of the group's flags are defined in flags; otherwise the group is
// skipped entirely.
func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
	groups, ok := pflag.Annotations[annotation]
	if !ok {
		return
	}

	for _, group := range groups {
		status := groupStatus[group]
		if status == nil {
			members := strings.Split(group, " ")

			// Only consider this flag group at all if all the flags are defined.
			if !hasAllFlags(flags, members...) {
				continue
			}

			status = make(map[string]bool, len(members))
			for _, member := range members {
				status[member] = false
			}
			groupStatus[group] = status
		}

		status[pflag.Name] = pflag.Changed
	}
}
143
144
func validateRequiredFlagGroups(data map[string]map[string]bool) error {
145
keys := sortedKeys(data)
146
for _, flagList := range keys {
147
flagnameAndStatus := data[flagList]
148
149
unset := []string{}
150
for flagname, isSet := range flagnameAndStatus {
151
if !isSet {
152
unset = append(unset, flagname)
153
}
154
}
155
if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
156
continue
157
}
158
159
// Sort values, so they can be tested/scripted against consistently.
160
sort.Strings(unset)
161
return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
162
}
163
164
return nil
165
}
166
167
func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
168
keys := sortedKeys(data)
169
for _, flagList := range keys {
170
flagnameAndStatus := data[flagList]
171
var set []string
172
for flagname, isSet := range flagnameAndStatus {
173
if isSet {
174
set = append(set, flagname)
175
}
176
}
177
if len(set) >= 1 {
178
continue
179
}
180
181
// Sort values, so they can be tested/scripted against consistently.
182
sort.Strings(set)
183
return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
184
}
185
return nil
186
}
187
188
func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
189
keys := sortedKeys(data)
190
for _, flagList := range keys {
191
flagnameAndStatus := data[flagList]
192
var set []string
193
for flagname, isSet := range flagnameAndStatus {
194
if isSet {
195
set = append(set, flagname)
196
}
197
}
198
if len(set) == 0 || len(set) == 1 {
199
continue
200
}
201
202
// Sort values, so they can be tested/scripted against consistently.
203
sort.Strings(set)
204
return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
205
}
206
return nil
207
}
208
209
// sortedKeys returns the keys of m in ascending lexical order. The result is
// a non-nil (possibly empty) slice, even when m is nil.
func sortedKeys(m map[string]map[string]bool) []string {
	out := make([]string, 0, len(m))
	for key := range m {
		out = append(out, key)
	}
	sort.Strings(out)
	return out
}
219
220
// enforceFlagGroupsForCompletion will do the following:
221
// - when a flag in a group is present, other flags in the group will be marked required
222
// - when none of the flags in a one-required group are present, all flags in the group will be marked required
223
// - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
224
// This allows the standard completion logic to behave appropriately for flag groups
225
func (c *Command) enforceFlagGroupsForCompletion() {
226
if c.DisableFlagParsing {
227
return
228
}
229
230
flags := c.Flags()
231
groupStatus := map[string]map[string]bool{}
232
oneRequiredGroupStatus := map[string]map[string]bool{}
233
mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
234
c.Flags().VisitAll(func(pflag *flag.Flag) {
235
processFlagForGroupAnnotation(flags, pflag, requiredAsGroupAnnotation, groupStatus)
236
processFlagForGroupAnnotation(flags, pflag, oneRequiredAnnotation, oneRequiredGroupStatus)
237
processFlagForGroupAnnotation(flags, pflag, mutuallyExclusiveAnnotation, mutuallyExclusiveGroupStatus)
238
})
239
240
// If a flag that is part of a group is present, we make all the other flags
241
// of that group required so that the shell completion suggests them automatically
242
for flagList, flagnameAndStatus := range groupStatus {
243
for _, isSet := range flagnameAndStatus {
244
if isSet {
245
// One of the flags of the group is set, mark the other ones as required
246
for _, fName := range strings.Split(flagList, " ") {
247
_ = c.MarkFlagRequired(fName)
248
}
249
}
250
}
251
}
252
253
// If none of the flags of a one-required group are present, we make all the flags
254
// of that group required so that the shell completion suggests them automatically
255
for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
256
isSet := false
257
258
for _, isSet = range flagnameAndStatus {
259
if isSet {
260
break
261
}
262
}
263
264
// None of the flags of the group are set, mark all flags in the group
265
// as required
266
if !isSet {
267
for _, fName := range strings.Split(flagList, " ") {
268
_ = c.MarkFlagRequired(fName)
269
}
270
}
271
}
272
273
// If a flag that is mutually exclusive to others is present, we hide the other
274
// flags of that group so the shell completion does not suggest them
275
for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
276
for flagName, isSet := range flagnameAndStatus {
277
if isSet {
278
// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
279
// Don't mark the flag that is already set as hidden because it may be an
280
// array or slice flag and therefore must continue being suggested
281
for _, fName := range strings.Split(flagList, " ") {
282
if fName != flagName {
283
flag := c.Flags().Lookup(fName)
284
flag.Hidden = true
285
}
286
}
287
}
288
}
289
}
290
}
291
292