TTW_Bot_GO/internal/ai/gigachat.go

170 lines
4.6 KiB
Go

package ai
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// gigaAuthResponse is the JSON body returned by the GigaChat OAuth endpoint.
type gigaAuthResponse struct {
	AccessToken string `json:"access_token"`
	ExpiresAt   int64  `json:"expires_at"` // token expiry, Unix time in milliseconds
	ExpiresIn   int    `json:"expires_in"` // lifetime in seconds; used only as a fallback
	TokenType   string `json:"token_type"`
}

// gigaMessage is one chat turn in a completion request.
type gigaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// gigaChatRequest is the JSON body sent to the chat/completions endpoint.
type gigaChatRequest struct {
	Model    string        `json:"model"`
	Messages []gigaMessage `json:"messages"`
	Stream   bool          `json:"stream"`
}

// gigaChatResponse holds the subset of the chat/completions response that is
// consumed here: the message content of each returned choice.
type gigaChatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}

// GigaChatProvider talks to the Sber GigaChat chat-completions API, obtaining
// and caching OAuth access tokens as needed.
//
// NOTE(review): accessToken and tokenExpiry are read and written without any
// locking. If Ask may be called from several goroutines at once this is a data
// race; guard them with a sync.Mutex — confirm the caller's concurrency model.
type GigaChatProvider struct {
	clientID     string // sent as the RqUID header
	authBasic    string // ready-made base64(clientID:secret) used for Basic auth
	endpoint     string // API base URL without a trailing slash
	model        string
	systemPrompt string
	httpClient   *http.Client
	accessToken  string    // cached OAuth token; "" means nothing is cached
	tokenExpiry  time.Time // instant after which the cached token is refreshed
}

// NewGigaChatProvider builds a GigaChat client.
//
// clientID is sent as the RqUID header and authBasic is the ready-made
// base64("clientID:secret") authorization key; both are whitespace-trimmed.
// An empty endpoint defaults to the public GigaChat v1 API and an empty model
// to "GigaChat". Surrounding whitespace and trailing slashes are stripped from
// endpoint so that endpoint+"/chat/completions" is always a well-formed URL.
// The returned client uses a 60-second overall request timeout.
func NewGigaChatProvider(clientID, authBasic, endpoint, model, systemPrompt string) *GigaChatProvider {
	endpoint = strings.TrimRight(strings.TrimSpace(endpoint), "/")
	if endpoint == "" {
		endpoint = "https://gigachat.devices.sberbank.ru/api/v1"
	}
	model = strings.TrimSpace(model)
	if model == "" {
		model = "GigaChat"
	}
	// SECURITY(review): certificate verification is disabled for every request
	// this client makes, including the OAuth call that carries the Basic
	// credentials. Presumably this works around the Sber endpoints being signed
	// by a root CA missing from default trust stores; prefer adding that CA to
	// TLSClientConfig.RootCAs and removing InsecureSkipVerify.
	tr := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}
	return &GigaChatProvider{
		clientID:     strings.TrimSpace(clientID),
		authBasic:    strings.TrimSpace(authBasic),
		endpoint:     endpoint,
		model:        model,
		systemPrompt: systemPrompt,
		httpClient:   &http.Client{Transport: tr, Timeout: 60 * time.Second},
	}
}

// getToken returns the cached OAuth access token. When nothing is cached, or
// the cached token has reached its refresh time, it requests a new one.
//
// The refresh time is the server-reported expiry minus a safety margin.
// expires_at (Unix ms) is preferred and expires_in (seconds) is the fallback.
// When neither is present, the token is returned but not cached.
func (p *GigaChatProvider) getToken(ctx context.Context) (string, error) {
	if p.accessToken != "" && time.Now().Before(p.tokenExpiry) {
		return p.accessToken, nil
	}

	const (
		authURL = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
		// Refresh this long before the reported expiry so a token does not
		// lapse between being handed out and being used.
		expirySkew = 60 * time.Second
	)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, authURL, strings.NewReader("scope=GIGACHAT_API_PERS"))
	if err != nil {
		return "", fmt.Errorf("gigachat auth: build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Basic "+p.authBasic)
	// NOTE(review): RqUID is meant to be a per-request UUIDv4; reusing clientID
	// works only if it is UUID-shaped — confirm.
	req.Header.Set("RqUID", p.clientID)
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("gigachat auth: %w", err)
	}
	defer resp.Body.Close()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("gigachat auth: read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("gigachat auth error: %d %s", resp.StatusCode, string(bodyBytes))
	}

	var authResp gigaAuthResponse
	if err := json.Unmarshal(bodyBytes, &authResp); err != nil {
		return "", fmt.Errorf("gigachat auth: decode response: %w", err)
	}
	if authResp.AccessToken == "" {
		return "", fmt.Errorf("gigachat auth: response contains no access_token")
	}

	var expiry time.Time
	switch {
	case authResp.ExpiresAt > 0:
		expiry = time.UnixMilli(authResp.ExpiresAt)
	case authResp.ExpiresIn > 0:
		expiry = time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
	default:
		// Lifetime unknown: use the token for this call but do not cache it.
		expiry = time.Now()
	}

	p.accessToken = authResp.AccessToken
	p.tokenExpiry = expiry.Add(-expirySkew)
	return p.accessToken, nil
}
// Ask sends prompt to the chat/completions endpoint and returns the content
// of the first choice in the response.
//
// The request contains the configured system prompt (omitted when empty or
// whitespace-only) followed by prompt as a single user turn. A non-200 status
// is returned as an error that includes the response body. On HTTP 401 the
// cached access token is discarded so the next call re-authenticates; the
// current call still returns the error.
func (p *GigaChatProvider) Ask(ctx context.Context, prompt string) (string, error) {
	token, err := p.getToken(ctx)
	if err != nil {
		return "", err
	}

	messages := make([]gigaMessage, 0, 2)
	// An empty system message carries no instructions, so it is not sent.
	if strings.TrimSpace(p.systemPrompt) != "" {
		messages = append(messages, gigaMessage{Role: "system", Content: p.systemPrompt})
	}
	messages = append(messages, gigaMessage{Role: "user", Content: prompt})

	jsonData, err := json.Marshal(gigaChatRequest{
		Model:    p.model,
		Messages: messages,
		Stream:   false,
	})
	if err != nil {
		return "", fmt.Errorf("gigachat: encode request: %w", err)
	}

	chatURL := p.endpoint + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, chatURL, bytes.NewReader(jsonData))
	if err != nil {
		return "", fmt.Errorf("gigachat: build request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("RqUID", p.clientID)
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("gigachat: %w", err)
	}
	defer resp.Body.Close()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("gigachat: read response: %w", err)
	}
	if resp.StatusCode == http.StatusUnauthorized {
		// The server rejected the token (revoked or expired early); drop it so
		// getToken fetches a fresh one on the next call instead of reusing it
		// until its local expiry.
		p.accessToken = ""
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("gigachat api error: %d %s", resp.StatusCode, string(bodyBytes))
	}

	var chatResp gigaChatResponse
	if err := json.Unmarshal(bodyBytes, &chatResp); err != nil {
		return "", fmt.Errorf("gigachat: decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return "", fmt.Errorf("no response from gigachat")
	}
	return chatResp.Choices[0].Message.Content, nil
}