214 lines
5.9 KiB
Go
214 lines
5.9 KiB
Go
package database
|
|
|
|
import (
|
|
"ai-code-assistant/pkg/config"
|
|
"context"
|
|
"embed"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-git/go-billy/v5/osfs"
|
|
"github.com/go-git/go-git/v5"
|
|
"github.com/go-git/go-git/v5/plumbing/cache"
|
|
"github.com/go-git/go-git/v5/storage/filesystem"
|
|
"github.com/golang-migrate/migrate/v4"
|
|
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
_ "github.com/lib/pq"
|
|
"github.com/tmc/langchaingo/embeddings"
|
|
"github.com/tmc/langchaingo/vectorstores"
|
|
"github.com/tmc/langchaingo/vectorstores/pgvector"
|
|
"path/filepath"
|
|
"strconv"
|
|
)
|
|
|
|
//go:embed migrations
|
|
var migrations embed.FS
|
|
|
|
type contextKey string
|
|
|
|
const (
|
|
contextKeyDB contextKey = "db"
|
|
)
|
|
|
|
// preparedStatements returns a newly built map from prepared-statement name to
// its SQL text. Every call returns a fresh map, so callers may modify it freely.
func preparedStatements() map[string]string {
	stmts := make(map[string]string, 4)

	// Rows in the repos table, keyed by HEAD hash and on-disk path.
	stmts["get_repo"] = `SELECT repo_id FROM repos WHERE repo_hash = $1 AND repo_path = $2`
	stmts["insert_repo"] = `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`

	// Embedding rows, matched on text values pulled out of their JSON metadata.
	stmts["get_chunk"] = `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`
	stmts["clear_chunks_for_repo"] = `DELETE FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$1`

	return stmts
}
|
|
|
|
// Database is a helper that abstracts access to the database as well as handles migrations.
|
|
type Database struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// db acquires a connection from the pool and prepares every statement from
// preparedStatements on it, so callers can execute them by name.
//
// On success the caller owns the returned connection and must call Release
// on it. On error no connection is held: anything acquired here has already
// been handed back to the pool.
func (db *Database) db(ctx context.Context) (*pgxpool.Conn, error) {
	conn, err := db.pool.Acquire(ctx)
	if err != nil {
		return nil, err
	}

	for name, query := range preparedStatements() {
		// pgx documents Prepare as idempotent for an identical name/SQL pair,
		// so repeating it on a connection that was prepared before is safe.
		if _, err := conn.Conn().Prepare(ctx, name, query); err != nil {
			// The caller never sees conn on this path; release it here or the
			// pool slot is lost for good.
			conn.Release()
			return nil, fmt.Errorf("problem preparing statement %s: %w", name, err)
		}
	}

	return conn, nil
}
|
|
|
|
// UpsertRepo ensures there is a repos row for repoPath at the commit its HEAD
// currently points to.
//
// It returns that row's repo_id and whether the row was already present
// (true) or was created by this call (false). Errors from opening the git
// repository, resolving HEAD, or querying the database are returned as-is.
func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) {
	gitPath := osfs.New(filepath.Join(repoPath, ".git"))

	gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath)
	if err != nil {
		return "", false, err
	}

	headRef, err := gitRepo.Head()
	if err != nil {
		return "", false, err
	}
	headHash := headRef.Hash().String()

	conn, err := db.db(ctx)
	if err != nil {
		return "", false, err
	}
	defer conn.Release()

	var id string
	err = conn.QueryRow(ctx, "get_repo", headHash, repoPath).Scan(&id)
	if err == nil {
		return id, true, nil
	}
	if !errors.Is(err, pgx.ErrNoRows) {
		return "", false, err
	}

	id = uuid.NewString()
	tag, err := conn.Exec(ctx, "insert_repo", id, headHash, repoPath)
	if err != nil {
		return "", false, err
	}
	if tag.RowsAffected() == 0 {
		// ON CONFLICT DO NOTHING skipped the insert. Most likely another caller
		// stored this repo between our lookup and insert, so the generated id
		// was never saved: return the stored row's id instead.
		// NOTE(review): assumes the conflicting unique constraint covers
		// (repo_hash, repo_path); confirm against the migrations.
		if err := conn.QueryRow(ctx, "get_repo", headHash, repoPath).Scan(&id); err != nil {
			return "", false, fmt.Errorf("insert of repo %q at %s was skipped by a conflict and no matching row was found: %w", repoPath, headHash, err)
		}
		return id, true, nil
	}

	return id, false, nil
}
|
|
|
|
// GetChunk returns the stored document text for chunk chunkID of the file at
// path inside the repository identified by repoID.
//
// The error from the lookup is returned unwrapped, so a missing chunk
// surfaces as pgx.ErrNoRows.
func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {
	conn, err := db.db(ctx)
	if err != nil {
		return "", err
	}
	defer conn.Release()

	// The metadata values are compared as extracted text, so the numeric id is
	// passed in its base-10 string form.
	var document string
	if err = conn.QueryRow(ctx, "get_chunk", strconv.Itoa(chunkID), path, repoID).Scan(&document); err != nil {
		return "", err
	}
	return document, nil
}
|
|
|
|
// ClearChunkIndex deletes every langchain_pg_embedding row whose metadata
// repo_id equals repoID. It returns any error from acquiring a connection or
// executing the delete.
func (db *Database) ClearChunkIndex(ctx context.Context, repoID string) error {
	conn, err := db.db(ctx)
	if err != nil {
		return err
	}
	defer conn.Release()

	_, err = conn.Exec(ctx, "clear_chunks_for_repo", repoID)
	return err
}
|
|
|
|
// GetVectorStore returns a pgvector-backed vector store over the
// "file_chunks" collection. It uses embedder and is bound to a single
// connection taken from the pool.
//
// On success the returned func releases that connection and must be called
// once the store is no longer needed. On error nothing is held and the
// returned func is nil.
func (db *Database) GetVectorStore(ctx context.Context, embedder embeddings.Embedder) (vectorstores.VectorStore, func(), error) {
	conn, err := db.db(ctx)
	if err != nil {
		return nil, nil, err
	}

	vectorStore, err := pgvector.New(ctx,
		pgvector.WithConn(conn),
		pgvector.WithEmbedder(embedder),
		pgvector.WithCollectionName("file_chunks"),
	)
	if err != nil {
		// The caller gets no release func on this path, so return the
		// connection to the pool here instead of leaking it.
		conn.Release()
		return nil, nil, err
	}

	return vectorStore, conn.Release, nil
}
|
|
|
|
// GetChunkContext returns the text of chunk chunkID of the file at path in
// repository repoID, together with up to distance neighbouring chunks on each
// side. Chunks are concatenated in ascending id order with no separator.
//
// The window is [chunkID-distance, chunkID+distance], clamped below at 0. A
// negative distance is treated as 0. Chunks inside the window that do not
// exist are skipped. Any other lookup error is returned.
func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int, path, repoID string) (string, error) {
	if distance < 0 {
		distance = 0
	}

	first := chunkID - distance
	if first < 0 {
		first = 0
	}
	last := chunkID + distance

	chunkContext := ""

	// A distinct loop variable keeps chunkID intact. The original shadowed it,
	// which made the loop bound move along with the index.
	for id := first; id <= last; id++ {
		chunk, err := db.GetChunk(ctx, id, path, repoID)
		switch {
		case err == nil:
			chunkContext += chunk
		case errors.Is(err, pgx.ErrNoRows):
			// Skip gaps, e.g. ids below the first stored chunk or past the end
			// of the file. The window is bounded, so the scan still terminates.
		default:
			return "", err
		}
	}

	return chunkContext, nil
}
|
|
|
|
// FromConfig builds a Database from cfg and is used for bootstrapping.
//
// It first applies any pending embedded SQL migrations against
// cfg.Database.ConnString. Already being up to date (migrate.ErrNoChange) is
// not an error. It then opens a pgx connection pool on the same connection
// string.
func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, error) {
	migFS, err := iofs.New(migrations, "migrations")
	if err != nil {
		return nil, err
	}

	mig, err := migrate.NewWithSourceInstance("iofs", migFS, cfg.Database.ConnString)
	if err != nil {
		return nil, err
	}

	upErr := mig.Up()
	// Close releases the migration source and the database handle migrate
	// opened. Without it they stay open for the life of the process.
	srcErr, dbErr := mig.Close()
	if upErr != nil && !errors.Is(upErr, migrate.ErrNoChange) {
		return nil, fmt.Errorf("problem performing database migrations: %w", upErr)
	}
	if srcErr != nil {
		return nil, fmt.Errorf("problem closing migration source: %w", srcErr)
	}
	if dbErr != nil {
		return nil, fmt.Errorf("problem closing migration database: %w", dbErr)
	}

	pool, err := pgxpool.New(ctx, cfg.Database.ConnString)
	if err != nil {
		return nil, err
	}

	return &Database{pool: pool}, nil
}
|
|
|
|
// FromContext returns the *Database that WrapContext stored in ctx.
//
// It panics if ctx carries no *Database under this package's key, so it
// should only be called on contexts derived from WrapContext.
func FromContext(ctx context.Context) *Database {
	stored := ctx.Value(contextKeyDB)
	return stored.(*Database)
}
|
|
|
|
// WrapContext returns a child of ctx that carries db, so code further down
// the call chain can retrieve it with FromContext.
func WrapContext(ctx context.Context, db *Database) context.Context {
	wrapped := context.WithValue(ctx, contextKeyDB, db)
	return wrapped
}
|