jinshan/pkg/ai/alBaiLian.go
2025-06-19 10:35:26 +08:00

168 lines
4.6 KiB
Go

package ai
import (
"context"
"fmt"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"sync"
)
// BaichuanClient wraps an OpenAI-compatible SDK client configured for the
// Alibaba Bailian (DashScope) compatible-mode endpoint.
//
// NOTE(review): the type is named "Baichuan" but the configured base URL is
// Alibaba Bailian (dashscope.aliyuncs.com) — the name looks like a leftover;
// confirm intent before renaming, as callers depend on it.
type BaichuanClient struct {
	client *openai.Client // underlying OpenAI-compatible HTTP client
	mu     sync.Mutex     // serializes ChatCompletion calls on this instance
}
// NewBaichuanClient returns a client bound to the Alibaba Bailian
// (DashScope) OpenAI-compatible endpoint, authenticated with apiKey.
func NewBaichuanClient(apiKey string) *BaichuanClient {
	c := openai.NewClient(
		option.WithBaseURL("https://dashscope.aliyuncs.com/compatible-mode/v1/"),
		option.WithAPIKey(apiKey),
	)
	return &BaichuanClient{client: c}
}
// ChatMessage is one turn of a chat exchange in the wire format expected
// by the API.
type ChatMessage struct {
	Role    string `json:"role"`    // one of "system", "user", "assistant"
	Content string `json:"content"` // message text
}
// ChatCompletionParams are the inputs to ChatCompletion.
type ChatCompletionParams struct {
	Model    string        `json:"model"`    // model name (e.g. qwen-plus, qwen-max)
	Messages []ChatMessage `json:"messages"` // full conversation history to send
	MaxToken int           `json:"maxToken"` // optional cap on generated tokens; <= 0 means "use API default"
	Stream   bool          `json:"stream"`   // optional; NOTE(review): currently ignored by ChatCompletion
}
// ChatCompletionResponse is the simplified result of a ChatCompletion call.
type ChatCompletionResponse struct {
	ID      string `json:"id"`      // request ID assigned by the API
	Content string `json:"content"` // text of the first returned choice ("" if none)
	Model   string `json:"model"`   // model that produced the response
	// Usage holds token accounting reported by the API.
	Usage struct {
		PromptTokens     int `json:"promptTokens"`     // tokens in the prompt
		CompletionTokens int `json:"completionTokens"` // tokens generated
		TotalTokens      int `json:"totalTokens"`      // prompt + completion
	} `json:"usage"`
}
// ChatCompletion sends the given conversation to the Alibaba Bailian
// OpenAI-compatible chat-completions API and maps the result into a
// ChatCompletionResponse.
//
// params.Stream is currently ignored: streaming requires the SDK's
// streaming call path and a different response shape.
// Returns a wrapped error if the underlying API call fails.
func (bc *BaichuanClient) ChatCompletion(ctx context.Context, params ChatCompletionParams) (*ChatCompletionResponse, error) {
	// NOTE(review): the SDK client is documented as safe for concurrent use;
	// this mutex serializes every request (including network I/O) through a
	// single client. Kept for behavioral compatibility — consider removing
	// if concurrent throughput matters.
	bc.mu.Lock()
	defer bc.mu.Unlock()

	// Convert our message format to the SDK's union type.
	openaiMessages := make([]openai.ChatCompletionMessageParamUnion, len(params.Messages))
	for i, msg := range params.Messages {
		switch msg.Role {
		case "system":
			openaiMessages[i] = openai.SystemMessage(msg.Content)
		case "assistant":
			openaiMessages[i] = openai.AssistantMessage(msg.Content)
		default: // any other role is treated as a user message
			openaiMessages[i] = openai.UserMessage(msg.Content)
		}
	}

	requestParams := openai.ChatCompletionNewParams{
		Messages: openai.F(openaiMessages),
		Model:    openai.F(params.Model),
	}
	// Fix: MaxToken was accepted but silently ignored (the earlier attempt
	// passed an int where the SDK field is int64, so it was commented out).
	// The public field stays int for interface compatibility.
	if params.MaxToken > 0 {
		requestParams.MaxTokens = openai.F(int64(params.MaxToken))
	}

	chatCompletion, err := bc.client.Chat.Completions.New(ctx, requestParams)
	if err != nil {
		return nil, fmt.Errorf("API调用失败: %w", err)
	}

	response := &ChatCompletionResponse{
		ID:    chatCompletion.ID,
		Model: chatCompletion.Model,
	}
	if len(chatCompletion.Choices) > 0 {
		response.Content = chatCompletion.Choices[0].Message.Content
	}
	// Fix: usage counters were declared on the response type but never
	// populated. In this SDK, Usage is a struct value (not a pointer) with
	// int64 counters, so the old nil-check never compiled — read it directly.
	response.Usage.PromptTokens = int(chatCompletion.Usage.PromptTokens)
	response.Usage.CompletionTokens = int(chatCompletion.Usage.CompletionTokens)
	response.Usage.TotalTokens = int(chatCompletion.Usage.TotalTokens)

	return response, nil
}
// Conversation is a stateful chat session: it accumulates the message
// history and replays the full history on every request.
type Conversation struct {
	Messages []ChatMessage   `json:"messages"` // ordered history, oldest first
	Model    string          `json:"model"`    // model used for every turn
	Client   *BaichuanClient `json:"-"`        // transport; excluded from serialization
}
// NewConversation starts a fresh conversation bound to this client.
// A non-empty systemPrompt becomes the leading system message.
func (bc *BaichuanClient) NewConversation(model string, systemPrompt string) *Conversation {
	conv := &Conversation{
		Messages: []ChatMessage{},
		Model:    model,
		Client:   bc,
	}
	if systemPrompt != "" {
		conv.Messages = append(conv.Messages, ChatMessage{Role: "system", Content: systemPrompt})
	}
	return conv
}
// AddMessage appends one message with the given role and content to the
// conversation history.
func (c *Conversation) AddMessage(role, content string) {
	msg := ChatMessage{Role: role, Content: content}
	c.Messages = append(c.Messages, msg)
}
// GetResponse sends the current history to the model, records the model's
// reply as an assistant message, and returns the reply text. On error the
// history is left unchanged and an empty string is returned.
func (c *Conversation) GetResponse(ctx context.Context, maxToken int) (string, error) {
	resp, err := c.Client.ChatCompletion(ctx, ChatCompletionParams{
		Model:    c.Model,
		Messages: c.Messages,
		MaxToken: maxToken,
	})
	if err != nil {
		return "", err
	}
	c.AddMessage("assistant", resp.Content)
	return resp.Content, nil
}
// ClearHistory resets the conversation, keeping everything up to and
// including the most recent system message. If no system message exists,
// the history becomes empty.
func (c *Conversation) ClearHistory() {
	lastSystem := -1
	for i, m := range c.Messages {
		if m.Role == "system" {
			lastSystem = i
		}
	}
	if lastSystem >= 0 {
		c.Messages = c.Messages[:lastSystem+1]
		return
	}
	c.Messages = []ChatMessage{}
}