// Package llm builds the code, chat and embedding models described by the
// application configuration and renders the prompt templates embedded in the
// binary.
package llm

import (
	"bytes"
	"context"
	"embed"
	"errors"
	"fmt"
	"text/template"

	"github.com/tmc/langchaingo/embeddings"
	"github.com/tmc/langchaingo/llms"

	"ai-code-assistant/pkg/config"
)

// contextKey is unexported so that values stored by this package cannot
// collide with context keys defined in other packages.
type contextKey string

const contextKeyLLM contextKey = "llm"

// prompts holds the template files under the prompts/ directory; GetPrompt
// reads "prompts/<name>.tmpl" from it.
//
//go:embed prompts
var prompts embed.FS

// LLM bundles the models used by the assistant: a model for code prompts,
// a model for chat prompts, and an embedder for turning text into vectors.
// Construct it with FromConfig.
type LLM struct {
	code     llms.Model
	chat     llms.Model
	embedder embeddings.Embedder
}

// FromConfig builds an LLM from the Embedding, Code and Chat sections of cfg.
// It returns an error if cfg is nil, if any model cannot be obtained from the
// configuration, or if the embedder cannot be created from the embedding model.
func FromConfig(cfg *config.Configuration) (*LLM, error) {
	if cfg == nil {
		return nil, errors.New("nil configuration")
	}

	embedLLM, err := cfg.Embedding.GetEmbedding()
	if err != nil {
		return nil, fmt.Errorf("unable to get embedding model: %w", err)
	}

	codeLLM, err := cfg.Code.GetModel()
	if err != nil {
		return nil, fmt.Errorf("unable to get coder model: %w", err)
	}

	chatLLM, err := cfg.Chat.GetModel()
	if err != nil {
		return nil, fmt.Errorf("unable to get chat model: %w", err)
	}

	embedder, err := embeddings.NewEmbedder(embedLLM)
	if err != nil {
		// Wrapped for consistency with the errors above; %w keeps the cause inspectable.
		return nil, fmt.Errorf("unable to create embedder: %w", err)
	}

	return &LLM{
		embedder: embedder,
		code:     codeLLM,
		chat:     chatLLM,
	}, nil
}

// FromContext returns the *LLM previously attached to ctx with WrapContext.
// It returns nil when ctx carries no *LLM, instead of panicking.
func FromContext(ctx context.Context) *LLM {
	l, ok := ctx.Value(contextKeyLLM).(*LLM)
	if !ok {
		return nil
	}
	return l
}

// WrapContext returns a copy of ctx that carries llmRef, retrievable with
// FromContext.
func WrapContext(ctx context.Context, llmRef *LLM) context.Context {
	return context.WithValue(ctx, contextKeyLLM, llmRef)
}

// GetEmbedding embeds texts with the configured embedder and returns the
// vectors produced by its EmbedDocuments method.
func (l *LLM) GetEmbedding(ctx context.Context, texts ...string) ([][]float32, error) {
	return l.embedder.EmbedDocuments(ctx, texts)
}

// Embedder returns the underlying embedder, for callers that need the
// langchaingo interface directly.
func (l *LLM) Embedder() embeddings.Embedder {
	return l.embedder
}

// CodePrompt sends prompt to the code model and returns its text response.
func (l *LLM) CodePrompt(ctx context.Context, prompt string) (string, error) {
	return l.code.Call(ctx, prompt)
}

// ChatPrompt sends prompt to the chat model and returns its text response.
func (l *LLM) ChatPrompt(ctx context.Context, prompt string) (string, error) {
	return l.chat.Call(ctx, prompt)
}

// GetPrompt loads the embedded template "prompts/<name>.tmpl", parses it with
// text/template, and returns the result of executing it against data.
// Errors identify the prompt name and the failing stage (read, parse or
// execute) and wrap the underlying cause.
func GetPrompt(name string, data any) (string, error) {
	tmplText, err := prompts.ReadFile("prompts/" + name + ".tmpl")
	if err != nil {
		return "", fmt.Errorf("reading prompt %q: %w", name, err)
	}

	tmpl, err := template.New(name).Parse(string(tmplText))
	if err != nil {
		return "", fmt.Errorf("parsing prompt %q: %w", name, err)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		return "", fmt.Errorf("executing prompt %q: %w", name, err)
	}
	return buf.String(), nil
}