// File: ai-code-assistant/pkg/llm/llm.go
package llm
import (
"ai-code-assistant/pkg/config"
"bytes"
"context"
"embed"
"fmt"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"text/template"
)
// contextKey is an unexported key type for context values, preventing
// collisions with context keys defined by other packages.
type contextKey string

// contextKeyLLM is the key under which WrapContext stores the *LLM.
const contextKeyLLM contextKey = "llm"

// prompts holds the template files embedded from the prompts directory
// at build time; GetPrompt reads individual "<name>.tmpl" files from it.
//go:embed prompts
var prompts embed.FS
// LLM is responsible for abstracting the configuration and implementations of the LLMs used.
type LLM struct {
	// code is the model used by CodePrompt.
	code llms.Model
	// chat is the model used by ChatPrompt.
	chat llms.Model
	// embedder produces vector embeddings; exposed via Embedder.
	embedder embeddings.Embedder
}
// FromConfig bootstraps the LLM from a passed in configuration.
// It resolves the embedding, code, and chat models from cfg, wraps the
// embedding model in an embeddings.Embedder, and returns the assembled
// LLM. A non-nil error is returned if any model cannot be constructed.
func FromConfig(cfg *config.Configuration) (*LLM, error) {
	embedLLM, err := cfg.Embedding.GetEmbedding()
	if err != nil {
		return nil, fmt.Errorf("unable to get embedding model: %w", err)
	}
	codeLLM, err := cfg.Code.GetModel()
	if err != nil {
		return nil, fmt.Errorf("unable to get coder model: %w", err)
	}
	chatLLM, err := cfg.Chat.GetModel()
	if err != nil {
		return nil, fmt.Errorf("unable to get chat model: %w", err)
	}
	embedder, err := embeddings.NewEmbedder(embedLLM)
	if err != nil {
		// Wrap for consistency with the other construction errors above;
		// previously this error was returned bare, without context.
		return nil, fmt.Errorf("unable to create embedder: %w", err)
	}
	return &LLM{
		embedder: embedder,
		code:     codeLLM,
		chat:     chatLLM,
	}, nil
}
// FromContext retrieves an LLM from a passed in context wrapped with WrapContext.
// The type assertion panics if the context was never wrapped, so callers
// must ensure WrapContext ran earlier in the request chain.
func FromContext(ctx context.Context) *LLM {
	stored := ctx.Value(contextKeyLLM)
	return stored.(*LLM)
}
// WrapContext embeds an LLM inside a context so it can be retrieved with FromContext.
// The value is stored under the package-private contextKeyLLM key.
func WrapContext(ctx context.Context, llmRef *LLM) context.Context {
	wrapped := context.WithValue(ctx, contextKeyLLM, llmRef)
	return wrapped
}
// Embedder gets an embedder that can be used to store and retrieve embeddings.
// It simply exposes the embedder configured in FromConfig.
func (llm *LLM) Embedder() embeddings.Embedder {
	embedder := llm.embedder
	return embedder
}
// CodePrompt passes a prompt to the code LLM and returns the response.
// The call result and error are forwarded to the caller unchanged.
func (llm *LLM) CodePrompt(ctx context.Context, prompt string) (string, error) {
	response, err := llm.code.Call(ctx, prompt)
	return response, err
}
// ChatPrompt passes a prompt to the chat LLM and returns the response.
// The call result and error are forwarded to the caller unchanged.
func (llm *LLM) ChatPrompt(ctx context.Context, prompt string) (string, error) {
	response, err := llm.chat.Call(ctx, prompt)
	return response, err
}
// GetPrompt loads a LLM prompt template and injects variables into it. Uses the go template format.
// The template is read from the embedded prompts FS as "prompts/<name>.tmpl",
// parsed, and executed with data as the template context. Errors from each
// stage are wrapped with the template name so callers can tell which prompt
// failed (they were previously returned bare).
func GetPrompt(name string, data any) (string, error) {
	tmplText, err := prompts.ReadFile("prompts/" + name + ".tmpl")
	if err != nil {
		return "", fmt.Errorf("unable to read prompt template %q: %w", name, err)
	}
	tmpl, err := template.New(name).Parse(string(tmplText))
	if err != nil {
		return "", fmt.Errorf("unable to parse prompt template %q: %w", name, err)
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return "", fmt.Errorf("unable to execute prompt template %q: %w", name, err)
	}
	return buf.String(), nil
}