залил
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type ChatGPTProvider struct{}
|
||||
|
||||
func NewChatGPTProvider(apiKey, model, systemPrompt string) *ChatGPTProvider {
|
||||
return &ChatGPTProvider{}
|
||||
}
|
||||
|
||||
func (p *ChatGPTProvider) Ask(ctx context.Context, prompt string) (string, error) {
|
||||
return "", fmt.Errorf("ChatGPT not implemented yet")
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"stream-bot/internal/db"
|
||||
)
|
||||
|
||||
func NewProvider(cfg *db.AIConfig) (Provider, error) {
|
||||
switch cfg.Provider {
|
||||
case "ollama":
|
||||
return NewOllamaProvider(cfg.Endpoint, cfg.Model, cfg.SystemPrompt), nil
|
||||
case "chatgpt":
|
||||
return NewChatGPTProvider(cfg.APIKey, cfg.Model, cfg.SystemPrompt), nil
|
||||
case "gigachat":
|
||||
if cfg.ClientID == "" || cfg.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("client_id and client_secret required for GigaChat")
|
||||
}
|
||||
return NewGigaChatProvider(cfg.ClientID, cfg.ClientSecret, cfg.Endpoint, cfg.Model, cfg.SystemPrompt), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type gigaAuthResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
type gigaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type gigaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []gigaMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type gigaChatResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type GigaChatProvider struct {
|
||||
clientID string
|
||||
authBasic string // уже готовый base64(clientID:secret)
|
||||
endpoint string
|
||||
model string
|
||||
systemPrompt string
|
||||
httpClient *http.Client
|
||||
accessToken string
|
||||
tokenExpiry time.Time
|
||||
}
|
||||
|
||||
func NewGigaChatProvider(clientID, authBasic, endpoint, model, systemPrompt string) *GigaChatProvider {
|
||||
if endpoint == "" {
|
||||
endpoint = "https://gigachat.devices.sberbank.ru/api/v1"
|
||||
}
|
||||
if model == "" {
|
||||
model = "GigaChat"
|
||||
}
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
return &GigaChatProvider{
|
||||
clientID: strings.TrimSpace(clientID),
|
||||
authBasic: strings.TrimSpace(authBasic),
|
||||
endpoint: endpoint,
|
||||
model: model,
|
||||
systemPrompt: systemPrompt,
|
||||
httpClient: &http.Client{Transport: tr, Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GigaChatProvider) getToken(ctx context.Context) (string, error) {
|
||||
if p.accessToken != "" && time.Now().Before(p.tokenExpiry) {
|
||||
return p.accessToken, nil
|
||||
}
|
||||
|
||||
authURL := "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
|
||||
bodyData := "scope=GIGACHAT_API_PERS"
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", authURL, strings.NewReader(bodyData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Basic "+p.authBasic)
|
||||
req.Header.Set("RqUID", p.clientID)
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("gigachat auth error: %d %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var authResp gigaAuthResponse
|
||||
if err := json.Unmarshal(bodyBytes, &authResp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
p.accessToken = authResp.AccessToken
|
||||
p.tokenExpiry = time.Now().Add(time.Duration(authResp.ExpiresIn-60) * time.Second)
|
||||
return p.accessToken, nil
|
||||
}
|
||||
|
||||
func (p *GigaChatProvider) Ask(ctx context.Context, prompt string) (string, error) {
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
messages := []gigaMessage{
|
||||
{Role: "system", Content: p.systemPrompt},
|
||||
{Role: "user", Content: prompt},
|
||||
}
|
||||
|
||||
reqBody := gigaChatRequest{
|
||||
Model: p.model,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
chatURL := p.endpoint + "/chat/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("RqUID", p.clientID)
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("gigachat api error: %d %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var chatResp gigaChatResponse
|
||||
if err := json.Unmarshal(bodyBytes, &chatResp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(chatResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response from gigachat")
|
||||
}
|
||||
return chatResp.Choices[0].Message.Content, nil
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OllamaProvider struct {
|
||||
endpoint string
|
||||
model string
|
||||
systemPrompt string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type ollamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ollamaResponse struct {
|
||||
Response string `json:"response"`
|
||||
}
|
||||
|
||||
func NewOllamaProvider(endpoint, model, systemPrompt string) *OllamaProvider {
|
||||
if endpoint == "" {
|
||||
endpoint = "http://localhost:11434"
|
||||
}
|
||||
return &OllamaProvider{
|
||||
endpoint: endpoint,
|
||||
model: model,
|
||||
systemPrompt: systemPrompt,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) Ask(ctx context.Context, prompt string) (string, error) {
|
||||
reqBody := ollamaRequest{
|
||||
Model: p.model,
|
||||
Prompt: prompt,
|
||||
System: p.systemPrompt,
|
||||
Stream: false,
|
||||
}
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
url := p.endpoint + "/api/generate"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
_ = Body.Close()
|
||||
}(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("ollama error: %d", resp.StatusCode)
|
||||
}
|
||||
var ollamaResp ollamaResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ollamaResp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ollamaResp.Response, nil
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package ai
|
||||
|
||||
import "context"
|
||||
|
||||
type Provider interface {
|
||||
Ask(ctx context.Context, prompt string) (string, error)
|
||||
}
|
||||
Reference in New Issue
Block a user