168 lines
4.6 KiB
Go
168 lines
4.6 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/openai/openai-go"
|
|
"github.com/openai/openai-go/option"
|
|
"sync"
|
|
)
|
|
|
|
// BaichuanClient wraps an OpenAI SDK client. NewBaichuanClient configures it
// for the DashScope "compatible-mode" endpoint.
//
// NOTE(review): the type is named "Baichuan", but the constructor's comment and
// the endpoint refer to Alibaba Bailian (百炼)/DashScope. Confirm the intended name.
type BaichuanClient struct {
	// client is the underlying OpenAI SDK client used to send requests.
	client *openai.Client

	// mu is locked by ChatCompletion for the whole call, including the API
	// request. At most one ChatCompletion per BaichuanClient runs at a time.
	mu sync.Mutex
}
|
|
|
|
// NewBaichuanClient 创建新的百炼客户端实例
|
|
func NewBaichuanClient(apiKey string) *BaichuanClient {
|
|
return &BaichuanClient{
|
|
client: openai.NewClient(
|
|
option.WithAPIKey(apiKey),
|
|
option.WithBaseURL("https://dashscope.aliyuncs.com/compatible-mode/v1/"),
|
|
),
|
|
}
|
|
}
|
|
|
|
// ChatMessage is a single message in a chat exchange.
type ChatMessage struct {
	Role    string `json:"role"`    // Role: "system", "user" or "assistant"; ChatCompletion sends any other value as a user message.
	Content string `json:"content"` // Message text.
}
|
|
|
|
// ChatCompletionParams holds the inputs for a chat completion request.
type ChatCompletionParams struct {
	Model    string        `json:"model"`    // Model name (e.g. qwen-plus, qwen-max).
	Messages []ChatMessage `json:"messages"` // Messages to send, in order.
	MaxToken int           `json:"maxToken"` // Maximum number of tokens to generate (optional).
	Stream   bool          `json:"stream"`   // Request a streamed response (optional). NOTE(review): ChatCompletion never issues a streaming request; verify before relying on this.
}
|
|
|
|
// ChatCompletionResponse is the simplified result of a chat completion call.
type ChatCompletionResponse struct {
	ID      string `json:"id"`      // Completion ID returned by the API.
	Content string `json:"content"` // Text of the first choice; empty when the API returned no choices.
	Model   string `json:"model"`   // Model reported by the API for this completion.

	// Usage holds token counts as reported by the API.
	// NOTE(review): confirm ChatCompletion fills these in; they remain zero otherwise.
	Usage struct {
		PromptTokens     int `json:"promptTokens"`     // Tokens in the prompt.
		CompletionTokens int `json:"completionTokens"` // Tokens in the generated completion.
		TotalTokens      int `json:"totalTokens"`      // Total tokens for the request.
	} `json:"usage"`
}
|
|
|
|
// ChatCompletion 调用阿里百炼智能体进行聊天完成
|
|
func (bc *BaichuanClient) ChatCompletion(ctx context.Context, params ChatCompletionParams) (*ChatCompletionResponse, error) {
|
|
bc.mu.Lock()
|
|
defer bc.mu.Unlock()
|
|
|
|
// 转换消息格式
|
|
openaiMessages := make([]openai.ChatCompletionMessageParamUnion, len(params.Messages))
|
|
for i, msg := range params.Messages {
|
|
switch msg.Role {
|
|
case "system":
|
|
openaiMessages[i] = openai.SystemMessage(msg.Content)
|
|
case "assistant":
|
|
openaiMessages[i] = openai.AssistantMessage(msg.Content)
|
|
default: // 默认为用户消息
|
|
openaiMessages[i] = openai.UserMessage(msg.Content)
|
|
}
|
|
}
|
|
|
|
// 设置请求参数
|
|
requestParams := openai.ChatCompletionNewParams{
|
|
Messages: openai.F(openaiMessages),
|
|
Model: openai.F(params.Model),
|
|
}
|
|
|
|
//// 设置可选参数
|
|
//if params.MaxToken > 0 {
|
|
// requestParams.MaxTokens = openai.F(params.MaxToken)
|
|
//}
|
|
//if params.Stream {
|
|
// requestParams.Stream = openai.F(true)
|
|
//}
|
|
|
|
// 调用API
|
|
chatCompletion, err := bc.client.Chat.Completions.New(ctx, requestParams)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("API调用失败: %w", err)
|
|
}
|
|
|
|
// 处理响应
|
|
response := &ChatCompletionResponse{
|
|
ID: chatCompletion.ID,
|
|
Model: chatCompletion.Model,
|
|
}
|
|
|
|
if len(chatCompletion.Choices) > 0 {
|
|
response.Content = chatCompletion.Choices[0].Message.Content
|
|
}
|
|
|
|
//// 添加使用量统计
|
|
//if chatCompletion.Usage != nil {
|
|
// response.Usage.PromptTokens = chatCompletion.Usage.PromptTokens
|
|
// response.Usage.CompletionTokens = chatCompletion.Usage.CompletionTokens
|
|
// response.Usage.TotalTokens = chatCompletion.Usage.TotalTokens
|
|
//}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// Conversation is a chat session. It accumulates message history and sends it
// through Client using Model.
//
// NOTE(review): Conversation has no internal locking; do not share one across
// goroutines without external synchronization.
type Conversation struct {
	Messages []ChatMessage  `json:"messages"` // Conversation history, oldest first.
	Model    string         `json:"model"`    // Model used for every request.
	Client   *BaichuanClient `json:"-"`       // Client used to send requests; excluded from JSON.
}
|
|
|
|
// NewConversation 创建新的对话会话
|
|
func (bc *BaichuanClient) NewConversation(model string, systemPrompt string) *Conversation {
|
|
messages := []ChatMessage{}
|
|
if systemPrompt != "" {
|
|
messages = append(messages, ChatMessage{Role: "system", Content: systemPrompt})
|
|
}
|
|
|
|
return &Conversation{
|
|
Messages: messages,
|
|
Model: model,
|
|
Client: bc,
|
|
}
|
|
}
|
|
|
|
// AddMessage 添加消息到对话历史
|
|
func (c *Conversation) AddMessage(role, content string) {
|
|
c.Messages = append(c.Messages, ChatMessage{
|
|
Role: role,
|
|
Content: content,
|
|
})
|
|
}
|
|
|
|
// GetResponse 获取AI对最新消息的回复
|
|
func (c *Conversation) GetResponse(ctx context.Context, maxToken int) (string, error) {
|
|
params := ChatCompletionParams{
|
|
Model: c.Model,
|
|
Messages: c.Messages,
|
|
MaxToken: maxToken,
|
|
}
|
|
|
|
response, err := c.Client.ChatCompletion(ctx, params)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// 将AI回复添加到对话历史
|
|
c.AddMessage("assistant", response.Content)
|
|
|
|
return response.Content, nil
|
|
}
|
|
|
|
// ClearHistory 清除对话历史,但保留系统提示
|
|
func (c *Conversation) ClearHistory() {
|
|
for i := len(c.Messages) - 1; i >= 0; i-- {
|
|
if c.Messages[i].Role == "system" {
|
|
c.Messages = c.Messages[:i+1]
|
|
return
|
|
}
|
|
}
|
|
c.Messages = []ChatMessage{}
|
|
}
|