170 lines
4.6 KiB
Go
170 lines
4.6 KiB
Go
package ai
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// gigaAuthResponse mirrors the JSON body of the GigaChat OAuth token
// response as it is decoded by getToken.
type gigaAuthResponse struct {
	// AccessToken is sent as the Bearer credential on chat requests.
	AccessToken string `json:"access_token"`
	// ExpiresIn is treated by the caller as a lifetime in seconds.
	ExpiresIn int `json:"expires_in"`
	// TokenType is decoded but not otherwise used here.
	TokenType string `json:"token_type"`
}
|
|
|
|
// gigaMessage is a single element of the "messages" array of a
// chat-completions request. Ask fills Role with "system" or "user".
//
// Field order matters: it is the order json.Marshal emits the keys in.
type gigaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
|
|
|
type gigaChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []gigaMessage `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
}
|
|
|
|
// gigaChatResponse decodes only the part of a chat-completions reply that
// Ask reads: the text content of each choice's message. Any other keys in
// the reply are ignored by encoding/json.
type gigaChatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}
|
|
|
|
// GigaChatProvider is a client for the Sber GigaChat chat-completions API.
// Create one with NewGigaChatProvider.
//
// NOTE(review): the cached OAuth token (accessToken/tokenExpiry) is read and
// written without any locking, so sharing one provider between goroutines
// looks unsafe — confirm before doing so.
type GigaChatProvider struct {
	// Credentials.
	clientID  string // also sent as the RqUID request header
	authBasic string // already-prepared base64(clientID:secret)

	// Request settings.
	endpoint     string // base API URL; "/chat/completions" is appended to it
	model        string
	systemPrompt string
	httpClient   *http.Client

	// Cached OAuth token and the moment it should be considered stale.
	accessToken string
	tokenExpiry time.Time
}
|
|
|
|
func NewGigaChatProvider(clientID, authBasic, endpoint, model, systemPrompt string) *GigaChatProvider {
|
|
if endpoint == "" {
|
|
endpoint = "https://gigachat.devices.sberbank.ru/api/v1"
|
|
}
|
|
if model == "" {
|
|
model = "GigaChat"
|
|
}
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
}
|
|
return &GigaChatProvider{
|
|
clientID: strings.TrimSpace(clientID),
|
|
authBasic: strings.TrimSpace(authBasic),
|
|
endpoint: endpoint,
|
|
model: model,
|
|
systemPrompt: systemPrompt,
|
|
httpClient: &http.Client{Transport: tr, Timeout: 60 * time.Second},
|
|
}
|
|
}
|
|
|
|
func (p *GigaChatProvider) getToken(ctx context.Context) (string, error) {
|
|
if p.accessToken != "" && time.Now().Before(p.tokenExpiry) {
|
|
return p.accessToken, nil
|
|
}
|
|
|
|
authURL := "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
|
|
bodyData := "scope=GIGACHAT_API_PERS"
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", authURL, strings.NewReader(bodyData))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("Authorization", "Basic "+p.authBasic)
|
|
req.Header.Set("RqUID", p.clientID)
|
|
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
bodyBytes, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("gigachat auth error: %d %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var authResp gigaAuthResponse
|
|
if err := json.Unmarshal(bodyBytes, &authResp); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
p.accessToken = authResp.AccessToken
|
|
p.tokenExpiry = time.Now().Add(time.Duration(authResp.ExpiresIn-60) * time.Second)
|
|
return p.accessToken, nil
|
|
}
|
|
|
|
func (p *GigaChatProvider) Ask(ctx context.Context, prompt string) (string, error) {
|
|
token, err := p.getToken(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
messages := []gigaMessage{
|
|
{Role: "system", Content: p.systemPrompt},
|
|
{Role: "user", Content: prompt},
|
|
}
|
|
|
|
reqBody := gigaChatRequest{
|
|
Model: p.model,
|
|
Messages: messages,
|
|
Stream: false,
|
|
}
|
|
jsonData, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
chatURL := p.endpoint + "/chat/completions"
|
|
req, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(jsonData))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
req.Header.Set("RqUID", p.clientID)
|
|
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
bodyBytes, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("gigachat api error: %d %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
var chatResp gigaChatResponse
|
|
if err := json.Unmarshal(bodyBytes, &chatResp); err != nil {
|
|
return "", err
|
|
}
|
|
if len(chatResp.Choices) == 0 {
|
|
return "", fmt.Errorf("no response from gigachat")
|
|
}
|
|
return chatResp.Choices[0].Message.Content, nil
|
|
} |