// Source: GitHub repository kardolus/chatgpt-cli
// Path: blob/main/api/client/client.go
1
package client
2
3
import (
4
"bytes"
5
"context"
6
"encoding/base64"
7
"encoding/json"
8
"errors"
9
"fmt"
10
"io"
11
"mime/multipart"
12
stdhttp "net/http"
13
"net/textproto"
14
"net/url"
15
"os"
16
"path/filepath"
17
"sort"
18
"strings"
19
"time"
20
"unicode/utf8"
21
22
"github.com/kardolus/chatgpt-cli/api"
23
"github.com/kardolus/chatgpt-cli/api/http"
24
"github.com/kardolus/chatgpt-cli/cmd/chatgpt/utils"
25
"github.com/kardolus/chatgpt-cli/config"
26
"github.com/kardolus/chatgpt-cli/history"
27
"github.com/kardolus/chatgpt-cli/internal"
28
29
"go.uber.org/zap"
30
"golang.org/x/text/cases"
31
"golang.org/x/text/language"
32
)
33
34
const (
	// Message roles used when building history entries.
	AssistantRole = "assistant"
	SystemRole    = "system"
	UserRole      = "user"
	FunctionRole  = "function"

	// Error texts. They are plain strings; call sites turn them into errors
	// with errors.New or fmt.Errorf.
	ErrEmptyResponse       = "empty response"
	ErrMissingMCPAPIKey    = "the %s api key is not configured" // format string, %s is the provider
	ErrUnsupportedProvider = "unsupported MCP provider"
	ErrHistoryTracking     = "history tracking needs to be enabled to use this feature"

	// MaxTokenBufferPercentage is the percentage handed to
	// calculateEffectiveContextWindow when the history is truncated.
	MaxTokenBufferPercentage = 20

	// InteractiveThreadPrefix prefixes the slug generated for threads that are
	// auto-created in interactive mode.
	InteractiveThreadPrefix = "int_"

	// Apify request pieces: the endpoint is ApifyURL + function + ApifyPath,
	// and ApifyProxyConfig is the parameter key added to every request body.
	ApifyURL         = "https://api.apify.com/v2/acts/"
	ApifyPath        = "/run-sync-get-dataset-items"
	ApifyProxyConfig = "proxyConfiguration"

	// Model-name fragments inspected by ListModels and GetCapabilities.
	SearchModelPattern = "-search"
	gptPrefix          = "gpt"
	o1Prefix           = "o1"
	o1ProPattern       = "o1-pro"
	gpt5Pattern        = "gpt-5"

	// Type tags found in request and response content items.
	audioType      = "input_audio"
	imageURLType   = "image_url"
	messageType    = "message"
	outputTextType = "output_text"

	// imageContent is the data-URL template: MIME type, then base64 payload.
	imageContent = "data:%s;base64,%s"

	httpScheme  = "http"
	httpsScheme = "https"

	// bufferSize is the number of bytes ReadBufferFromFile allocates and reads.
	bufferSize = 512
)
62
63
// Timer supplies the current time. The client timestamps history entries
// through it, so an alternative implementation can be injected.
type Timer interface {
	Now() time.Time
}

// RealTime is the Timer backed by the system clock.
type RealTime struct{}

// Now returns the result of time.Now.
func (r *RealTime) Now() time.Time {
	current := time.Now()
	return current
}
73
74
// FileReader is the file-access seam used by the client for reading input
// files (images, audio).
type FileReader interface {
	// ReadFile returns the whole content of the named file.
	ReadFile(name string) ([]byte, error)
	// ReadBufferFromFile returns a leading chunk of an already opened file.
	ReadBufferFromFile(file *os.File) ([]byte, error)
	// Open opens the named file for reading.
	Open(name string) (*os.File, error)
}

// RealFileReader implements FileReader on top of the os package.
type RealFileReader struct{}

// Open delegates to os.Open.
func (r *RealFileReader) Open(name string) (*os.File, error) {
	handle, err := os.Open(name)
	return handle, err
}

// ReadFile delegates to os.ReadFile.
func (r *RealFileReader) ReadFile(name string) ([]byte, error) {
	content, err := os.ReadFile(name)
	return content, err
}
89
90
// ReadBufferFromFile issues a single Read of up to bufferSize bytes on file.
//
// The returned slice always has length bufferSize; bytes beyond what was
// actually read stay zero. Any error from Read (including io.EOF on an empty
// file) is returned together with the buffer.
func (r *RealFileReader) ReadBufferFromFile(file *os.File) ([]byte, error) {
	head := make([]byte, bufferSize)
	if _, err := file.Read(head); err != nil {
		return head, err
	}
	return head, nil
}
96
97
// FileWriter is the file-access seam used by the client for writing output
// files (audio, images).
type FileWriter interface {
	// Write writes buf to the already opened file.
	Write(file *os.File, buf []byte) error
	// Create creates (or truncates) the named file.
	Create(name string) (*os.File, error)
}

// RealFileWriter implements FileWriter on top of the os package.
type RealFileWriter struct{}

// Create delegates to os.Create.
func (w *RealFileWriter) Create(name string) (*os.File, error) {
	return os.Create(name)
}

// Write writes buf to file and reports only the error; the byte count
// returned by (*os.File).Write is discarded.
func (w *RealFileWriter) Write(file *os.File, buf []byte) error {
	_, err := file.Write(buf)
	return err
}
112
113
type Client struct {
114
Config config.Config
115
History []history.History
116
caller http.Caller
117
historyStore history.Store
118
timer Timer
119
reader FileReader
120
writer FileWriter
121
}
122
123
func New(callerFactory http.CallerFactory, hs history.Store, t Timer, r FileReader, w FileWriter, cfg config.Config, interactiveMode bool) *Client {
124
caller := callerFactory(cfg)
125
126
if interactiveMode && cfg.AutoCreateNewThread {
127
hs.SetThread(internal.GenerateUniqueSlug(InteractiveThreadPrefix))
128
} else {
129
hs.SetThread(cfg.Thread)
130
}
131
132
return &Client{
133
Config: cfg,
134
caller: caller,
135
historyStore: hs,
136
timer: t,
137
reader: r,
138
writer: w,
139
}
140
}
141
142
// WithContextWindow overrides Config.ContextWindow on the receiver and returns
// the same client so calls can be chained.
func (c *Client) WithContextWindow(window int) *Client {
	c.Config.ContextWindow = window

	return c
}
146
147
// WithServiceURL overrides Config.URL (the base that getEndpoint prepends to
// every path) on the receiver and returns the same client for chaining.
func (c *Client) WithServiceURL(url string) *Client {
	c.Config.URL = url

	return c
}
151
152
// InjectMCPContext calls an MCP plugin (e.g. Apify) with the given parameters,
153
// retrieves the result, and adds it to the chat history as a function message.
154
// The result is formatted as a string and tagged with the function name.
155
func (c *Client) InjectMCPContext(mcp api.MCPRequest) error {
156
if c.Config.OmitHistory {
157
return errors.New(ErrHistoryTracking)
158
}
159
160
endpoint, headers, body, err := c.buildMCPRequest(mcp)
161
if err != nil {
162
return err
163
}
164
165
c.printRequestDebugInfo(endpoint, body, headers)
166
167
raw, err := c.caller.PostWithHeaders(endpoint, body, headers)
168
if err != nil {
169
return err
170
}
171
172
c.printResponseDebugInfo(raw)
173
174
formatted := formatMCPResponse(raw, mcp.Function)
175
176
c.initHistory()
177
c.History = append(c.History, history.History{
178
Message: api.Message{
179
Role: FunctionRole,
180
Name: strings.ReplaceAll(mcp.Function, "~", "-"),
181
Content: formatted,
182
},
183
Timestamp: c.timer.Now(),
184
})
185
c.truncateHistory()
186
187
return c.historyStore.Write(c.History)
188
}
189
190
// ListModels retrieves a list of all available models from the OpenAI API.
191
// The models are returned as a slice of strings, each entry representing a model ID.
192
// Models that have an ID starting with 'gpt' are included.
193
// The currently active model is marked with an asterisk (*) in the list.
194
// In case of an error during the retrieval or processing of the models,
195
// the method returns an error. If the API response is empty, an error is returned as well.
196
func (c *Client) ListModels() ([]string, error) {
197
var result []string
198
199
endpoint := c.getEndpoint(c.Config.ModelsPath)
200
201
c.printRequestDebugInfo(endpoint, nil, nil)
202
203
raw, err := c.caller.Get(c.getEndpoint(c.Config.ModelsPath))
204
c.printResponseDebugInfo(raw)
205
206
if err != nil {
207
return nil, err
208
}
209
210
var response api.ListModelsResponse
211
if err := c.processResponse(raw, &response); err != nil {
212
return nil, err
213
}
214
215
sort.Slice(response.Data, func(i, j int) bool {
216
return response.Data[i].Id < response.Data[j].Id
217
})
218
219
for _, model := range response.Data {
220
if strings.HasPrefix(model.Id, gptPrefix) || strings.HasPrefix(model.Id, o1Prefix) {
221
if model.Id != c.Config.Model {
222
result = append(result, fmt.Sprintf("- %s", model.Id))
223
continue
224
}
225
result = append(result, fmt.Sprintf("* %s (current)", model.Id))
226
}
227
}
228
229
return result, nil
230
}
231
232
// ProvideContext adds custom context to the client's history by converting the
233
// provided string into a series of messages. This allows the ChatGPT API to have
234
// prior knowledge of the provided context when generating responses.
235
//
236
// The context string should contain the text you want to provide as context,
237
// and the method will split it into messages, preserving punctuation and special
238
// characters.
239
func (c *Client) ProvideContext(context string) {
240
c.initHistory()
241
historyEntries := c.createHistoryEntriesFromString(context)
242
c.History = append(c.History, historyEntries...)
243
}
244
245
// Query sends a query to the API, returning the response as a string along with the token usage.
246
//
247
// It takes a context `ctx` and an input string, constructs a request body, and makes a POST API call.
248
// The context allows for request scoping, timeouts, and cancellation handling.
249
//
250
// Returns the API response string, the number of tokens used, and an error if any issues occur.
251
// If the response contains choices, it decodes the JSON and returns the content of the first choice.
252
//
253
// Parameters:
254
// - ctx: A context.Context that controls request cancellation and deadlines.
255
// - input: The query string to send to the API.
256
//
257
// Returns:
258
// - string: The content of the first response choice from the API.
259
// - int: The total number of tokens used in the request.
260
// - error: An error if the request fails or the response is invalid.
261
func (c *Client) Query(ctx context.Context, input string) (string, int, error) {
262
c.prepareQuery(input)
263
264
body, err := c.createBody(ctx, false)
265
if err != nil {
266
return "", 0, err
267
}
268
269
endpoint := c.getChatEndpoint()
270
271
c.printRequestDebugInfo(endpoint, body, nil)
272
273
raw, err := c.caller.Post(endpoint, body, false)
274
c.printResponseDebugInfo(raw)
275
276
if err != nil {
277
return "", 0, err
278
}
279
280
var (
281
response string
282
tokensUsed int
283
)
284
285
caps := GetCapabilities(c.Config.Model)
286
287
if caps.UsesResponsesAPI {
288
var res api.ResponsesResponse
289
if err := c.processResponse(raw, &res); err != nil {
290
return "", 0, err
291
}
292
tokensUsed = res.Usage.TotalTokens
293
294
for _, output := range res.Output {
295
if output.Type != messageType {
296
continue
297
}
298
for _, content := range output.Content {
299
if content.Type == outputTextType {
300
response = content.Text
301
break
302
}
303
}
304
}
305
306
if response == "" {
307
return "", tokensUsed, errors.New("no response returned")
308
}
309
} else {
310
var res api.CompletionsResponse
311
if err := c.processResponse(raw, &res); err != nil {
312
return "", 0, err
313
}
314
tokensUsed = res.Usage.TotalTokens
315
316
if len(res.Choices) == 0 {
317
return "", tokensUsed, errors.New("no responses returned")
318
}
319
320
var ok bool
321
response, ok = res.Choices[0].Message.Content.(string)
322
if !ok {
323
return "", tokensUsed, errors.New("response cannot be converted to a string")
324
}
325
}
326
327
c.updateHistory(response)
328
329
return response, tokensUsed, nil
330
}
331
332
// Stream sends a query to the API and processes the response as a stream.
333
//
334
// It takes a context `ctx` and an input string, constructs a request body, and makes a POST API call.
335
// The context allows for request scoping, timeouts, and cancellation handling.
336
//
337
// The method creates a request body with the input and calls the API using the `Post` method.
338
// The actual processing of the streamed response is handled inside the `Post` method.
339
//
340
// Parameters:
341
// - ctx: A context.Context that controls request cancellation and deadlines.
342
// - input: The query string to send to the API.
343
//
344
// Returns:
345
// - error: An error if the request fails or the response is invalid.
346
func (c *Client) Stream(ctx context.Context, input string) error {
347
c.prepareQuery(input)
348
349
body, err := c.createBody(ctx, true)
350
if err != nil {
351
return err
352
}
353
354
endpoint := c.getChatEndpoint()
355
356
c.printRequestDebugInfo(endpoint, body, nil)
357
358
result, err := c.caller.Post(endpoint, body, true)
359
if err != nil {
360
return err
361
}
362
363
c.updateHistory(string(result))
364
365
return nil
366
}
367
368
// SynthesizeSpeech converts the given input text into speech using the configured TTS model,
369
// and writes the resulting audio to the specified output file.
370
//
371
// The audio format is inferred from the output file's extension (e.g., "mp3", "wav") and sent
372
// as the "response_format" in the request to the OpenAI speech synthesis endpoint.
373
//
374
// Parameters:
375
// - inputText: The text to synthesize into speech.
376
// - outputPath: The path to the output audio file. The file extension determines the response format.
377
//
378
// Returns an error if the request fails, the response cannot be written, or the file cannot be created.
379
func (c *Client) SynthesizeSpeech(inputText, outputPath string) error {
380
req := api.Speech{
381
Model: c.Config.Model,
382
Voice: c.Config.Voice,
383
Input: inputText,
384
ResponseFormat: getExtension(outputPath),
385
}
386
return c.postAndWriteBinaryOutput(c.getEndpoint(c.Config.SpeechPath), req, outputPath, "binary", nil)
387
}
388
389
// GenerateImage sends a prompt to the configured image generation model (e.g., gpt-image-1)
390
// and writes the resulting image to the specified output path.
391
//
392
// The method performs the following steps:
393
// 1. Sends a POST request to the image generation endpoint with the provided prompt.
394
// 2. Parses the response and extracts the base64-encoded image data.
395
// 3. Decodes the image bytes and writes them to the given outputPath.
396
// 4. Logs the number of bytes written using debug output.
397
//
398
// Parameters:
399
// - inputText: The prompt describing the image to be generated.
400
// - outputPath: The file path where the generated image (e.g., .png) will be saved.
401
//
402
// Returns:
403
// - An error if any part of the request, decoding, or file writing fails.
404
func (c *Client) GenerateImage(inputText, outputPath string) error {
405
req := api.Draw{
406
Model: c.Config.Model,
407
Prompt: inputText,
408
}
409
410
return c.postAndWriteBinaryOutput(
411
c.getEndpoint(c.Config.ImageGenerationsPath),
412
req,
413
outputPath,
414
"image",
415
func(respBytes []byte) ([]byte, error) {
416
var response struct {
417
Data []struct {
418
B64 string `json:"b64_json"`
419
} `json:"data"`
420
}
421
if err := json.Unmarshal(respBytes, &response); err != nil {
422
return nil, fmt.Errorf("failed to decode response: %w", err)
423
}
424
if len(response.Data) == 0 {
425
return nil, fmt.Errorf("no image data returned")
426
}
427
decoded, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
428
if err != nil {
429
return nil, fmt.Errorf("failed to decode base64 image: %w", err)
430
}
431
return decoded, nil
432
},
433
)
434
}
435
436
// EditImage edits an input image using a text prompt and writes the modified image to the specified output path.
437
//
438
// This method sends a multipart/form-data POST request to the image editing endpoint
439
// (typically OpenAI's /v1/images/edits). The request includes:
440
// - The image file to edit.
441
// - A text prompt describing how the image should be modified.
442
// - The model ID (e.g., gpt-image-1).
443
//
444
// The response is expected to contain a base64-encoded image, which is decoded and written to the outputPath.
445
//
446
// Parameters:
447
// - inputText: A text prompt describing the desired modifications to the image.
448
// - inputPath: The file path to the source image (must be a supported format: PNG, JPEG, or WebP).
449
// - outputPath: The file path where the edited image will be saved.
450
//
451
// Returns:
452
// - An error if any step of the process fails: reading the file, building the request, sending it,
453
// decoding the response, or writing the output image.
454
//
455
// Example:
456
//
457
// err := client.EditImage("Add a rainbow in the sky", "input.png", "output.png")
458
// if err != nil {
459
// log.Fatal(err)
460
// }
461
func (c *Client) EditImage(inputText, inputPath, outputPath string) error {
462
endpoint := c.getEndpoint(c.Config.ImageEditsPath)
463
464
file, err := c.reader.Open(inputPath)
465
if err != nil {
466
return fmt.Errorf("failed to open input image: %w", err)
467
}
468
defer file.Close()
469
470
var buf bytes.Buffer
471
writer := multipart.NewWriter(&buf)
472
473
mimeType, err := c.getMimeTypeFromFileContent(inputPath)
474
if err != nil {
475
return fmt.Errorf("failed to detect MIME type: %w", err)
476
}
477
if !strings.HasPrefix(mimeType, "image/") {
478
return fmt.Errorf("unsupported MIME type: %s", mimeType)
479
}
480
481
header := make(textproto.MIMEHeader)
482
header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="image"; filename="%s"`, filepath.Base(inputPath)))
483
header.Set("Content-Type", mimeType)
484
485
part, err := writer.CreatePart(header)
486
if err != nil {
487
return fmt.Errorf("failed to create image part: %w", err)
488
}
489
if _, err := io.Copy(part, file); err != nil {
490
return fmt.Errorf("failed to copy image data: %w", err)
491
}
492
493
if err := writer.WriteField("prompt", inputText); err != nil {
494
return fmt.Errorf("failed to add prompt: %w", err)
495
}
496
if err := writer.WriteField("model", c.Config.Model); err != nil {
497
return fmt.Errorf("failed to add model: %w", err)
498
}
499
500
if err := writer.Close(); err != nil {
501
return fmt.Errorf("failed to close multipart writer: %w", err)
502
}
503
504
c.printRequestDebugInfo(endpoint, buf.Bytes(), map[string]string{
505
"Content-Type": writer.FormDataContentType(),
506
})
507
508
respBytes, err := c.caller.PostWithHeaders(endpoint, buf.Bytes(), map[string]string{
509
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
510
internal.HeaderContentTypeKey: writer.FormDataContentType(),
511
})
512
if err != nil {
513
return fmt.Errorf("failed to edit image: %w", err)
514
}
515
516
// Parse the JSON and extract b64_json
517
var response struct {
518
Data []struct {
519
B64 string `json:"b64_json"`
520
} `json:"data"`
521
}
522
if err := json.Unmarshal(respBytes, &response); err != nil {
523
return fmt.Errorf("failed to decode response: %w", err)
524
}
525
if len(response.Data) == 0 {
526
return fmt.Errorf("no image data returned")
527
}
528
529
imgBytes, err := base64.StdEncoding.DecodeString(response.Data[0].B64)
530
if err != nil {
531
return fmt.Errorf("failed to decode base64 image: %w", err)
532
}
533
534
outFile, err := c.writer.Create(outputPath)
535
if err != nil {
536
return fmt.Errorf("failed to create output file: %w", err)
537
}
538
defer outFile.Close()
539
540
if err := c.writer.Write(outFile, imgBytes); err != nil {
541
return fmt.Errorf("failed to write image: %w", err)
542
}
543
544
c.printResponseDebugInfo([]byte(fmt.Sprintf("[image] %d bytes written to %s", len(imgBytes), outputPath)))
545
return nil
546
}
547
548
// Transcribe uploads an audio file to the OpenAI transcription endpoint and returns the transcribed text.
549
//
550
// It reads the audio file from the provided `audioPath`, creates a multipart/form-data request with the model name
551
// and the audio file, and sends it to the endpoint defined by the `TranscriptionsPath` in the client config.
552
// The method expects a JSON response containing a "text" field with the transcription result.
553
//
554
// Parameters:
555
// - audioPath: The local file path to the audio file to be transcribed.
556
//
557
// Returns:
558
// - string: The transcribed text from the audio file.
559
// - error: An error if the file can't be read, the request fails, or the response is invalid.
560
//
561
// This method supports formats like mp3, mp4, mpeg, mpga, m4a, wav, and webm, depending on API compatibility.
562
func (c *Client) Transcribe(audioPath string) (string, error) {
563
c.initHistory()
564
565
file, err := c.reader.Open(audioPath)
566
if err != nil {
567
return "", fmt.Errorf("failed to open audio file: %w", err)
568
}
569
defer file.Close()
570
571
var buf bytes.Buffer
572
writer := multipart.NewWriter(&buf)
573
574
_ = writer.WriteField("model", c.Config.Model)
575
576
part, err := writer.CreateFormFile("file", filepath.Base(audioPath))
577
if err != nil {
578
return "", err
579
}
580
if _, err := io.Copy(part, file); err != nil {
581
return "", err
582
}
583
584
if err := writer.Close(); err != nil {
585
return "", err
586
}
587
588
endpoint := c.getEndpoint(c.Config.TranscriptionsPath)
589
headers := map[string]string{
590
internal.HeaderContentTypeKey: writer.FormDataContentType(),
591
c.Config.AuthHeader: fmt.Sprintf("%s %s", c.Config.AuthTokenPrefix, c.Config.APIKey),
592
}
593
594
c.printRequestDebugInfo(endpoint, buf.Bytes(), headers)
595
596
raw, err := c.caller.PostWithHeaders(endpoint, buf.Bytes(), headers)
597
if err != nil {
598
return "", err
599
}
600
601
c.printResponseDebugInfo(raw)
602
603
var res struct {
604
Text string `json:"text"`
605
}
606
if err := json.Unmarshal(raw, &res); err != nil {
607
return "", fmt.Errorf("failed to parse transcription: %w", err)
608
}
609
610
c.History = append(c.History, history.History{
611
Message: api.Message{
612
Role: UserRole,
613
Content: fmt.Sprintf("[transcribe] %s", filepath.Base(audioPath)),
614
},
615
Timestamp: c.timer.Now(),
616
})
617
618
c.History = append(c.History, history.History{
619
Message: api.Message{
620
Role: AssistantRole,
621
Content: res.Text,
622
},
623
Timestamp: c.timer.Now(),
624
})
625
626
c.truncateHistory()
627
628
if !c.Config.OmitHistory {
629
_ = c.historyStore.Write(c.History)
630
}
631
632
return res.Text, nil
633
}
634
635
func (c *Client) appendMediaMessages(ctx context.Context, messages []api.Message) ([]api.Message, error) {
636
if data, ok := ctx.Value(internal.BinaryDataKey).([]byte); ok {
637
content, err := c.createImageContentFromBinary(data)
638
if err != nil {
639
return nil, err
640
}
641
messages = append(messages, api.Message{
642
Role: UserRole,
643
Content: []api.ImageContent{content},
644
})
645
} else if path, ok := ctx.Value(internal.ImagePathKey).(string); ok {
646
content, err := c.createImageContentFromURLOrFile(path)
647
if err != nil {
648
return nil, err
649
}
650
messages = append(messages, api.Message{
651
Role: UserRole,
652
Content: []api.ImageContent{content},
653
})
654
} else if path, ok := ctx.Value(internal.AudioPathKey).(string); ok {
655
content, err := c.createAudioContentFromFile(path)
656
if err != nil {
657
return nil, err
658
}
659
messages = append(messages, api.Message{
660
Role: UserRole,
661
Content: []api.AudioContent{content},
662
})
663
}
664
return messages, nil
665
}
666
667
// createBody builds the JSON request body for the current history.
//
// Models whose capabilities report UsesResponsesAPI get a Responses API
// request; all others get a Completions request. stream is passed through to
// the request builder.
func (c *Client) createBody(ctx context.Context, stream bool) ([]byte, error) {
	if !GetCapabilities(c.Config.Model).UsesResponsesAPI {
		completions, err := c.createCompletionsRequest(ctx, stream)
		if err != nil {
			return nil, err
		}
		return json.Marshal(completions)
	}

	responses, err := c.createResponsesRequest(ctx, stream)
	if err != nil {
		return nil, err
	}
	return json.Marshal(responses)
}
684
685
func (c *Client) createCompletionsRequest(ctx context.Context, stream bool) (*api.CompletionsRequest, error) {
686
var messages []api.Message
687
caps := GetCapabilities(c.Config.Model)
688
689
for index, item := range c.History {
690
if caps.OmitFirstSystemMsg && index == 0 {
691
continue
692
}
693
messages = append(messages, item.Message)
694
}
695
696
messages, err := c.appendMediaMessages(ctx, messages)
697
if err != nil {
698
return nil, err
699
}
700
701
req := &api.CompletionsRequest{
702
Messages: messages,
703
Model: c.Config.Model,
704
MaxTokens: c.Config.MaxTokens,
705
FrequencyPenalty: c.Config.FrequencyPenalty,
706
PresencePenalty: c.Config.PresencePenalty,
707
Seed: c.Config.Seed,
708
Stream: stream,
709
}
710
711
if caps.SupportsTemperature {
712
req.Temperature = c.Config.Temperature
713
req.TopP = c.Config.TopP
714
}
715
716
return req, nil
717
}
718
719
func (c *Client) createResponsesRequest(ctx context.Context, stream bool) (*api.ResponsesRequest, error) {
720
var messages []api.Message
721
caps := GetCapabilities(c.Config.Model)
722
723
for index, item := range c.History {
724
if caps.OmitFirstSystemMsg && index == 0 {
725
continue
726
}
727
messages = append(messages, item.Message)
728
}
729
730
messages, err := c.appendMediaMessages(ctx, messages)
731
if err != nil {
732
return nil, err
733
}
734
735
req := &api.ResponsesRequest{
736
Model: c.Config.Model,
737
Input: messages,
738
MaxOutputTokens: c.Config.MaxTokens,
739
Reasoning: api.Reasoning{
740
Effort: c.Config.Effort,
741
},
742
Stream: stream,
743
Temperature: c.Config.Temperature,
744
TopP: c.Config.TopP,
745
}
746
747
return req, nil
748
}
749
750
func (c *Client) createImageContentFromBinary(binary []byte) (api.ImageContent, error) {
751
mime, err := getMimeTypeFromBytes(binary)
752
if err != nil {
753
return api.ImageContent{}, err
754
}
755
756
encoded := base64.StdEncoding.EncodeToString(binary)
757
content := api.ImageContent{
758
Type: imageURLType,
759
ImageURL: struct {
760
URL string `json:"url"`
761
}{
762
URL: fmt.Sprintf(imageContent, mime, encoded),
763
},
764
}
765
766
return content, nil
767
}
768
769
func (c *Client) createAudioContentFromFile(audio string) (api.AudioContent, error) {
770
771
format, err := c.detectAudioFormat(audio)
772
if err != nil {
773
return api.AudioContent{}, err
774
}
775
776
encodedAudio, err := c.base64Encode(audio)
777
if err != nil {
778
return api.AudioContent{}, err
779
}
780
781
return api.AudioContent{
782
Type: audioType,
783
InputAudio: api.InputAudio{
784
Data: encodedAudio,
785
Format: format,
786
},
787
}, nil
788
}
789
790
func (c *Client) createImageContentFromURLOrFile(image string) (api.ImageContent, error) {
791
var content api.ImageContent
792
793
if isValidURL(image) {
794
content = api.ImageContent{
795
Type: imageURLType,
796
ImageURL: struct {
797
URL string `json:"url"`
798
}{
799
URL: image,
800
},
801
}
802
} else {
803
mime, err := c.getMimeTypeFromFileContent(image)
804
if err != nil {
805
return content, err
806
}
807
808
encodedImage, err := c.base64Encode(image)
809
if err != nil {
810
return content, err
811
}
812
813
content = api.ImageContent{
814
Type: imageURLType,
815
ImageURL: struct {
816
URL string `json:"url"`
817
}{
818
URL: fmt.Sprintf(imageContent, mime, encodedImage),
819
},
820
}
821
}
822
823
return content, nil
824
}
825
826
func (c *Client) initHistory() {
827
if len(c.History) != 0 {
828
return
829
}
830
831
if !c.Config.OmitHistory {
832
c.History, _ = c.historyStore.Read()
833
}
834
835
if len(c.History) == 0 {
836
c.History = []history.History{{
837
Message: api.Message{
838
Role: SystemRole,
839
},
840
Timestamp: c.timer.Now(),
841
}}
842
}
843
844
c.History[0].Content = c.Config.Role
845
}
846
847
func (c *Client) addQuery(query string) {
848
message := api.Message{
849
Role: UserRole,
850
Content: query,
851
}
852
853
c.History = append(c.History, history.History{
854
Message: message,
855
Timestamp: c.timer.Now(),
856
})
857
c.truncateHistory()
858
}
859
860
// getChatEndpoint returns the full URL for chat requests: the Responses path
// for models that use the Responses API, the Completions path otherwise.
func (c *Client) getChatEndpoint() string {
	if GetCapabilities(c.Config.Model).UsesResponsesAPI {
		return c.getEndpoint(c.Config.ResponsesPath)
	}
	return c.getEndpoint(c.Config.CompletionsPath)
}
871
872
// getEndpoint returns Config.URL with path appended verbatim; no separator is
// inserted or normalised.
func (c *Client) getEndpoint(path string) string {
	base := c.Config.URL
	return base + path
}
875
876
// prepareQuery makes sure the history is initialised and then records input as
// the next user message.
func (c *Client) prepareQuery(input string) {
	c.initHistory()

	c.addQuery(input)
}
880
881
// processResponse decodes a raw JSON response into v.
//
// A nil raw slice yields an error with the ErrEmptyResponse text (an empty but
// non-nil slice is passed on to the decoder). Decoding failures are wrapped
// with the prefix "failed to decode response".
func (c *Client) processResponse(raw []byte, v interface{}) error {
	if raw == nil {
		return errors.New(ErrEmptyResponse)
	}

	err := json.Unmarshal(raw, v)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to decode response: %w", err)
}
892
893
// truncateHistory drops the oldest non-initial history entries when the token
// count exceeds the effective context window.
//
// countTokens supplies the total and a per-position slice (presumably the token
// count of each history entry; confirm against countTokens). Starting at
// position 1, counts are summed until the sum is strictly greater than the
// excess; entries 1 through that position are removed. Entry 0 is always kept.
// If the sum never exceeds the excess, nothing is removed.
func (c *Client) truncateHistory() {
	tokens, rolling := countTokens(c.History)
	limit := calculateEffectiveContextWindow(c.Config.ContextWindow, MaxTokenBufferPercentage)
	if tokens <= limit {
		return
	}

	excess := tokens - limit

	// cut is the last position to remove; 0 means "remove nothing".
	cut := 0
	for i, sum := 1, 0; i < len(rolling); i++ {
		sum += rolling[i]
		if sum > excess {
			cut = i
			break
		}
	}

	c.History = append(c.History[:1], c.History[cut+1:]...)
}
915
916
func (c *Client) updateHistory(response string) {
917
c.History = append(c.History, history.History{
918
Message: api.Message{
919
Role: AssistantRole,
920
Content: response,
921
},
922
Timestamp: c.timer.Now(),
923
})
924
925
if !c.Config.OmitHistory {
926
_ = c.historyStore.Write(c.History)
927
}
928
}
929
930
// base64Encode reads the whole file at path through the client's FileReader
// and returns its content in standard base64 encoding. A read error is
// returned with an empty string.
func (c *Client) base64Encode(path string) (string, error) {
	content, err := c.reader.ReadFile(path)
	if err != nil {
		return "", err
	}

	encoded := base64.StdEncoding.EncodeToString(content)
	return encoded, nil
}
938
939
func (c *Client) createHistoryEntriesFromString(input string) []history.History {
940
var result []history.History
941
942
words := strings.Fields(input)
943
944
for i := 0; i < len(words); i += 100 {
945
end := i + 100
946
if end > len(words) {
947
end = len(words)
948
}
949
950
content := strings.Join(words[i:end], " ")
951
952
item := history.History{
953
Message: api.Message{
954
Role: UserRole,
955
Content: content,
956
},
957
Timestamp: c.timer.Now(),
958
}
959
result = append(result, item)
960
}
961
962
return result
963
}
964
965
// detectAudioFormat identifies the audio container of the file at path from
// the leading bytes returned by the client's FileReader.
//
// Recognised signatures, checked in this order:
//   - "RIFF" at 0 and "WAVE" at 8             -> "wav"
//   - "ID3" at 0, or 0xFF followed by a byte
//     whose top three bits are set            -> "mp3"
//   - "fLaC" at 0                             -> "flac"
//   - "OggS" at 0                             -> "ogg"
//   - "ftyp" at 4: brand "M4A ", "isom" or
//     "mp42" at 8 -> "m4a", any other brand   -> "mp4"
//
// Anything else, including a buffer too short to hold a signature, yields
// "unknown" with a nil error. Errors from opening or reading the file are
// returned with an empty format.
func (c *Client) detectAudioFormat(path string) (string, error) {
	file, err := c.reader.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	buf, err := c.reader.ReadBufferFromFile(file)
	if err != nil {
		return "", err
	}

	// magicAt reports whether buf contains magic at offset. It never slices
	// past len(buf), so a short buffer from the reader cannot cause a panic.
	magicAt := func(offset int, magic string) bool {
		end := offset + len(magic)
		return end <= len(buf) && string(buf[offset:end]) == magic
	}

	switch {
	case magicAt(0, "RIFF") && magicAt(8, "WAVE"):
		return "wav", nil

	// MP3: ID3 tag or frame sync bits.
	case magicAt(0, "ID3"), len(buf) >= 2 && buf[0] == 0xFF && (buf[1]&0xE0) == 0xE0:
		return "mp3", nil

	case magicAt(0, "fLaC"):
		return "flac", nil

	case magicAt(0, "OggS"):
		return "ogg", nil

	// M4A / MP4: the brand follows the "ftyp" box type.
	case magicAt(4, "ftyp"):
		if magicAt(8, "M4A ") || magicAt(8, "isom") || magicAt(8, "mp42") {
			return "m4a", nil
		}
		return "mp4", nil
	}

	return "unknown", nil
}
1007
1008
// getMimeTypeFromFileContent opens the file at path, reads its leading bytes
// through the client's FileReader and returns the MIME type reported by
// net/http's DetectContentType. Open and read errors are returned with an
// empty string.
func (c *Client) getMimeTypeFromFileContent(path string) (string, error) {
	file, err := c.reader.Open(path)
	if err != nil {
		return "", err
	}
	defer file.Close()

	head, err := c.reader.ReadBufferFromFile(file)
	if err != nil {
		return "", err
	}

	return stdhttp.DetectContentType(head), nil
}
1024
1025
// printRequestDebugInfo logs, at debug level, a cURL command that mirrors the
// request about to be sent.
//
// The method is "GET" when body is nil and "POST" otherwise. When headers is
// non-empty those headers are printed in sorted order, and the value of any
// credential header (Config.AuthHeader or the Authorization header) is
// redacted so that API keys never reach the log. When headers is empty the
// default headers are printed instead, with an environment-variable
// placeholder in place of the API key, followed by Config.CustomHeaders.
// Single quotes in the body are escaped for use inside a single-quoted shell
// argument.
func (c *Client) printRequestDebugInfo(endpoint string, body []byte, headers map[string]string) {
	sugar := zap.S()
	sugar.Debugf("\nGenerated cURL command:\n")

	method := "POST"
	if body == nil {
		method = "GET"
	}
	sugar.Debugf("curl --location --insecure --request %s '%s' \\", method, endpoint)

	if len(headers) > 0 {
		// Sort the names: map iteration order is random and would otherwise
		// make the logged command differ from run to run.
		names := make([]string, 0, len(headers))
		for name := range headers {
			names = append(names, name)
		}
		sort.Strings(names)

		for _, name := range names {
			value := headers[name]
			if strings.EqualFold(name, c.Config.AuthHeader) || strings.EqualFold(name, internal.HeaderAuthorizationKey) {
				// Keep the scheme ("Bearer", ...) but never log the secret.
				if sp := strings.IndexByte(value, ' '); sp >= 0 {
					value = value[:sp] + " <redacted>"
				} else {
					value = "<redacted>"
				}
			}
			sugar.Debugf("  --header '%s: %s' \\", name, value)
		}
	} else {
		sugar.Debugf("  --header \"%s: %s${%s_API_KEY}\" \\", c.Config.AuthHeader, c.Config.AuthTokenPrefix, strings.ToUpper(c.Config.Name))
		sugar.Debugf("  --header '%s: %s' \\", internal.HeaderContentTypeKey, internal.HeaderContentTypeValue)
		sugar.Debugf("  --header '%s: %s' \\", internal.HeaderUserAgentKey, c.Config.UserAgent)

		// Include custom headers from config
		for k, v := range c.Config.CustomHeaders {
			sugar.Debugf("  --header '%s: %s' \\", k, v)
		}
	}

	if body != nil {
		// Close the quote, emit a double-quoted single quote, reopen the quote.
		bodyString := strings.ReplaceAll(string(body), "'", "'\"'\"'")
		sugar.Debugf("  --data-raw '%s'", bodyString)
	}
}
1055
1056
// printResponseDebugInfo logs the raw response bytes at debug level, preceded
// by a "Response" heading.
func (c *Client) printResponseDebugInfo(raw []byte) {
	logger := zap.S()

	logger.Debugf("\nResponse\n")
	logger.Debugf("%s\n", raw)
}
1061
1062
// postAndWriteBinaryOutput marshals requestBody to JSON, POSTs it to endpoint
// and writes the result to outputPath.
//
// When transform is non-nil the response bytes are passed through it first
// (for example to extract and decode an embedded payload); its error is
// returned unchanged. debugLabel is used in the write error message and in the
// debug line that reports how many bytes were written.
func (c *Client) postAndWriteBinaryOutput(endpoint string, requestBody interface{}, outputPath, debugLabel string, transform func([]byte) ([]byte, error)) error {
	payload, err := json.Marshal(requestBody)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	c.printRequestDebugInfo(endpoint, payload, nil)

	output, err := c.caller.Post(endpoint, payload, false)
	if err != nil {
		return fmt.Errorf("API request failed: %w", err)
	}

	if transform != nil {
		if output, err = transform(output); err != nil {
			return err
		}
	}

	outFile, err := c.writer.Create(outputPath)
	if err != nil {
		return fmt.Errorf("failed to create output file: %w", err)
	}
	defer outFile.Close()

	if err := c.writer.Write(outFile, output); err != nil {
		return fmt.Errorf("failed to write %s: %w", debugLabel, err)
	}

	summary := fmt.Sprintf("[%s] %d bytes written to %s", debugLabel, len(output), outputPath)
	c.printResponseDebugInfo([]byte(summary))
	return nil
}
1095
1096
func (c *Client) buildMCPRequest(mcp api.MCPRequest) (string, map[string]string, []byte, error) {
1097
mcp.Provider = strings.ToLower(mcp.Provider)
1098
params := mcp.Params
1099
1100
if mcp.Provider != utils.ApifyProvider {
1101
return "", nil, nil, errors.New(ErrUnsupportedProvider)
1102
}
1103
1104
apiKey := c.Config.ApifyAPIKey
1105
if apiKey == "" {
1106
return "", nil, nil, fmt.Errorf(ErrMissingMCPAPIKey, mcp.Provider)
1107
}
1108
1109
params[ApifyProxyConfig] = api.ProxyConfiguration{UseApifyProxy: true}
1110
endpoint := ApifyURL + mcp.Function + ApifyPath
1111
1112
headers := map[string]string{
1113
internal.HeaderContentTypeKey: internal.HeaderContentTypeValue,
1114
internal.HeaderAuthorizationKey: fmt.Sprintf("Bearer %s", apiKey),
1115
}
1116
1117
body, err := json.Marshal(params)
1118
if err != nil {
1119
return "", nil, nil, fmt.Errorf("failed to marshal request: %w", err)
1120
}
1121
1122
return endpoint, headers, body, nil
1123
}
1124
1125
// ModelCapabilities is the set of per-model feature flags that
// GetCapabilities derives from a model name.
type ModelCapabilities struct {
	// SupportsTemperature is false for models whose name contains
	// SearchModelPattern.
	SupportsTemperature bool
	// SupportsStreaming is false for models whose name contains o1ProPattern.
	SupportsStreaming bool
	// UsesResponsesAPI is true for models whose name contains o1ProPattern
	// or gpt5Pattern.
	UsesResponsesAPI bool
	// OmitFirstSystemMsg is true for models whose name starts with o1Prefix
	// but does not contain o1ProPattern.
	OmitFirstSystemMsg bool
}
1131
1132
func GetCapabilities(model string) ModelCapabilities {
1133
return ModelCapabilities{
1134
SupportsTemperature: !strings.Contains(model, SearchModelPattern),
1135
SupportsStreaming: !strings.Contains(model, o1ProPattern),
1136
UsesResponsesAPI: strings.Contains(model, o1ProPattern) || strings.Contains(model, gpt5Pattern),
1137
OmitFirstSystemMsg: strings.HasPrefix(model, o1Prefix) && !strings.Contains(model, o1ProPattern),
1138
}
1139
}
1140
1141
func formatMCPResponse(raw []byte, function string) string {
1142
var result interface{}
1143
if err := json.Unmarshal(raw, &result); err != nil {
1144
return fmt.Sprintf("[MCP: %s] (failed to decode response)", function)
1145
}
1146
1147
var lines []string
1148
1149
switch v := result.(type) {
1150
case []interface{}:
1151
if len(v) == 0 {
1152
return fmt.Sprintf("[MCP: %s] (no data returned)", function)
1153
}
1154
if obj, ok := v[0].(map[string]interface{}); ok {
1155
lines = formatKeyValues(obj)
1156
} else {
1157
return fmt.Sprintf("[MCP: %s] (unexpected response format)", function)
1158
}
1159
case map[string]interface{}:
1160
lines = formatKeyValues(v)
1161
default:
1162
return fmt.Sprintf("[MCP: %s] (unexpected response format)", function)
1163
}
1164
1165
sort.Strings(lines)
1166
return fmt.Sprintf("[MCP: %s]\n%s", function, strings.Join(lines, "\n"))
1167
}
1168
1169
func formatKeyValues(obj map[string]interface{}) []string {
1170
var lines []string
1171
caser := cases.Title(language.English)
1172
for k, val := range obj {
1173
label := caser.String(strings.ReplaceAll(k, "_", " "))
1174
lines = append(lines, fmt.Sprintf("%s: %v", label, val))
1175
}
1176
return lines
1177
}
1178
1179
// calculateEffectiveContextWindow returns the part of window that remains
// after reserving bufferPercentage percent of it. The computation uses integer
// arithmetic, so the result is truncated toward zero.
func calculateEffectiveContextWindow(window int, bufferPercentage int) int {
	usablePercent := 100 - bufferPercentage
	return window * usablePercent / 100
}
1184
1185
// countTokens estimates how many tokens the given history entries occupy.
//
// It returns the total estimate together with a slice holding the estimate
// for each entry, in input order. An entry's estimate is
// (runes + words) / 2 over the whitespace-separated words of its content.
// This is only a rough approximation; a real tokenizer may report a
// different count.
//
// Content that is not a string (including nil) is counted as zero tokens
// instead of panicking on the type assertion. Such an entry still gets its
// own slot in the per-entry slice, so indices stay aligned with entries.
func countTokens(entries []history.History) (int, []int) {
	var total int
	var rolling []int

	for _, entry := range entries {
		// Comma-ok form: text is "" when Content holds anything but a string.
		text, _ := entry.Content.(string)
		words := strings.Fields(text)

		runeCount := 0
		for _, word := range words {
			runeCount += utf8.RuneCountInString(word)
		}

		estimate := (runeCount + len(words)) / 2
		total += estimate
		rolling = append(rolling, estimate)
	}

	return total, rolling
}
1207
1208
// getExtension returns the file extension of path without its leading dot
// (for example "mp4" for "clip.mp4"), or "" when path has no extension.
func getExtension(path string) string {
	// TrimPrefix leaves an empty extension untouched, so no separate
	// empty check is needed.
	return strings.TrimPrefix(filepath.Ext(path), ".")
}
1215
1216
// getMimeTypeFromBytes sniffs the MIME type of data using the standard
// library's content-type detection. The returned error is always nil.
func getMimeTypeFromBytes(data []byte) (string, error) {
	return stdhttp.DetectContentType(data), nil
}
1221
1222
func isValidURL(input string) bool {
1223
parsedURL, err := url.ParseRequestURI(input)
1224
if err != nil {
1225
return false
1226
}
1227
1228
// Ensure that the URL has a valid scheme
1229
schemes := []string{httpScheme, httpsScheme}
1230
for _, scheme := range schemes {
1231
if strings.HasPrefix(parsedURL.Scheme, scheme) {
1232
return true
1233
}
1234
}
1235
1236
return false
1237
}
1238
1239