Files
ai-code-assistant/pkg/config/config.go

110 lines
2.4 KiB
Go

package config
import (
"context"
"fmt"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/llms/openai"
)
// contextKey is an unexported type for keys this package stores in a
// context.Context; using a package-private type keeps them distinct from
// keys defined by any other package.
type contextKey string

// contextKeyConfig is the key under which WrapContext stores the
// *Configuration and from which FromContext reads it.
const contextKeyConfig = contextKey("config")
// Configuration is a simple configuration that can be loaded from a YAML file.
// The yaml struct tags define the document keys each field is read from.
type Configuration struct {
	// Database groups the database settings (yaml key "database").
	Database struct {
		// ConnString is the database connection string (yaml key "conn_string").
		ConnString string `yaml:"conn_string"`
	} `yaml:"database"`

	// Logging groups the logging settings (yaml key "logging").
	Logging struct {
		// Level is the configured log level as a raw string; its accepted
		// values are interpreted elsewhere — NOTE(review): confirm with callers.
		Level string `yaml:"level"`
	} `yaml:"logging"`

	// Embedding, Code and Chat each configure a separate LLM backend.
	Embedding LLMConfig `yaml:"embedding"`
	Code      LLMConfig `yaml:"code"`
	Chat      LLMConfig `yaml:"chat"`

	// RelevantDocs presumably caps how many documents are retrieved per
	// query — TODO confirm against its users.
	RelevantDocs int `yaml:"relevant_docs"`
	// IndexChunkSize presumably sets the chunk size used when indexing
	// content — TODO confirm units (bytes vs. tokens) against its users.
	IndexChunkSize int `yaml:"index_chunk_size"`
}
type (
	// LLMConfig chooses an LLM provider through Type and carries the
	// settings for each supported provider. Model, when non-empty, selects
	// the model name passed to the provider's client.
	LLMConfig struct {
		Type   string       `yaml:"type"`
		Model  string       `yaml:"model"`
		OLlama OLlamaConfig `yaml:"ollama"`
		OpenAI OpenAIConfig `yaml:"openai"`
	}

	// OLlamaConfig holds the settings used when Type is "ollama".
	// URL is the Ollama server address; it is only forwarded to the
	// client when non-empty.
	OLlamaConfig struct {
		URL string `yaml:"url"`
	}

	// OpenAIConfig holds the settings used when Type is "openai": the API
	// key and an optional base URL, each forwarded only when non-empty.
	OpenAIConfig struct {
		Key string `yaml:"key"`
		URL string `yaml:"url"`
	}
)
// GetModel builds a text-generation client for the provider named by
// cfg.Type. Supported values are "ollama" and "openai"; any other value
// yields an error. Optional settings (Model, the provider URL, the OpenAI
// key) are only passed to the client when non-empty, so the provider
// library's defaults apply otherwise.
//
// On error the returned llms.Model is a true nil interface, so callers may
// safely test it against nil.
func (cfg LLMConfig) GetModel() (llms.Model, error) {
	switch cfg.Type {
	case "ollama":
		llm, err := cfg.newOllama()
		if err != nil {
			// Return a literal nil: wrapping a nil *ollama.LLM in the
			// interface would make the result compare non-nil.
			return nil, err
		}
		return llm, nil
	case "openai":
		llm, err := cfg.newOpenAI()
		if err != nil {
			return nil, err
		}
		return llm, nil
	default:
		return nil, fmt.Errorf("unknown model type: %q", cfg.Type)
	}
}

// GetEmbedding builds an embedding client for the provider named by
// cfg.Type. Only "ollama" is supported; any other value yields an error.
// On error the returned EmbedderClient is a true nil interface.
func (cfg LLMConfig) GetEmbedding() (embeddings.EmbedderClient, error) {
	switch cfg.Type {
	case "ollama":
		llm, err := cfg.newOllama()
		if err != nil {
			return nil, err
		}
		return llm, nil
	default:
		return nil, fmt.Errorf("unknown embedding type: %q", cfg.Type)
	}
}

// newOllama creates an Ollama client, forwarding Model and OLlama.URL only
// when they are set. Shared by GetModel and GetEmbedding.
func (cfg LLMConfig) newOllama() (*ollama.LLM, error) {
	var opts []ollama.Option
	if cfg.Model != "" {
		opts = append(opts, ollama.WithModel(cfg.Model))
	}
	if cfg.OLlama.URL != "" {
		opts = append(opts, ollama.WithServerURL(cfg.OLlama.URL))
	}
	llm, err := ollama.New(opts...)
	if err != nil {
		return nil, fmt.Errorf("creating ollama client: %w", err)
	}
	return llm, nil
}

// newOpenAI creates an OpenAI client, forwarding Model, OpenAI.URL and
// OpenAI.Key only when they are set.
func (cfg LLMConfig) newOpenAI() (*openai.LLM, error) {
	var opts []openai.Option
	if cfg.Model != "" {
		opts = append(opts, openai.WithModel(cfg.Model))
	}
	if cfg.OpenAI.URL != "" {
		opts = append(opts, openai.WithBaseURL(cfg.OpenAI.URL))
	}
	if cfg.OpenAI.Key != "" {
		opts = append(opts, openai.WithToken(cfg.OpenAI.Key))
	}
	llm, err := openai.New(opts...)
	if err != nil {
		return nil, fmt.Errorf("creating openai client: %w", err)
	}
	return llm, nil
}
// FromContext returns the *Configuration attached to ctx by WrapContext.
// If ctx carries no configuration under this package's key, it returns nil
// instead of panicking on a failed type assertion.
func FromContext(ctx context.Context) *Configuration {
	cfg, _ := ctx.Value(contextKeyConfig).(*Configuration)
	return cfg
}
// WrapContext derives a child of ctx that carries cfg under this package's
// private key; FromContext reads it back. The parent ctx is not modified.
func WrapContext(ctx context.Context, cfg *Configuration) context.Context {
	derived := context.WithValue(ctx, contextKeyConfig, cfg)
	return derived
}