Files
ai-code-assistant/pkg/database/database.go
Michael Powers 25f8cae8cb Code Cleanup and Quality of Life
Checks to make sure repo is indexed before generating code.
Don't generate tests for changes to tests.
Remove unused code.
Fix bootstrapping issue with langchaingo tables.
2025-04-20 08:31:26 -04:00

167 lines
4.2 KiB
Go

// Package database manages the PostgreSQL connection pool and the embedded
// schema migrations used to record indexed repositories and to read back
// stored document chunks.
package database
import (
	"context"
	"embed"
	"errors"
	"fmt"
	"path/filepath"
	"strconv"
	"strings"

	"ai-code-assistant/pkg/config"
	"github.com/go-git/go-billy/v5/osfs"
	"github.com/go-git/go-git/v5"
	"github.com/go-git/go-git/v5/plumbing/cache"
	"github.com/go-git/go-git/v5/storage/filesystem"
	"github.com/golang-migrate/migrate/v4"
	_ "github.com/golang-migrate/migrate/v4/database/postgres"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	_ "github.com/lib/pq"
)
// migrations holds the SQL migration files from the package's "migrations"
// directory, embedded at build time; FromConfig applies them with
// golang-migrate before opening the connection pool.
//
//go:embed migrations
var migrations embed.FS
// contextKey is a package-private key type for context values, so keys set by
// this package cannot collide with keys set by other packages.
type contextKey string
const (
// contextKeyDB is the key under which WrapContext stores the *Database.
contextKeyDB contextKey = "db"
)
// preparedStatements returns a fresh map from statement name to SQL text.
// DB prepares every entry on each connection it hands out, so callers can pass
// these names (rather than raw SQL) to QueryRow/Exec.
func preparedStatements() map[string]string {
	stmts := make(map[string]string, 4)

	// Repository bookkeeping.
	stmts["get_repo"] = `SELECT repo_id FROM repos WHERE repo_hash = $1 AND repo_path = $2`
	stmts["insert_repo"] = `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`

	// Chunk lookups against the embeddings table; metadata fields are
	// compared as text extracted from the cmetadata column.
	stmts["get_chunk"] = `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`
	stmts["clear_chunks_for_repo"] = `DELETE FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$1`

	return stmts
}
// Database gives access to the application's PostgreSQL database through a
// pgx connection pool. Construct it with FromConfig; the zero value has no
// pool and is not usable.
type Database struct {
// pool supplies connections; DB acquires one per operation.
pool *pgxpool.Pool
}
// DB acquires a connection from the pool and prepares every statement from
// preparedStatements on it, so callers can pass statement names such as
// "get_repo" to QueryRow/Exec. On success the caller owns the connection and
// must call Release on it.
//
// If preparing any statement fails, the connection is released back to the
// pool before the error is returned, so a failed call never leaks it.
func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
	conn, err := db.pool.Acquire(ctx)
	if err != nil {
		return nil, fmt.Errorf("acquiring database connection: %w", err)
	}
	for name, query := range preparedStatements() {
		// Pooled connections are reused; pgx documents Prepare as safe to call
		// again with the same name and SQL, so re-preparing here is harmless.
		if _, err := conn.Conn().Prepare(ctx, name, query); err != nil {
			conn.Release()
			return nil, fmt.Errorf("problem preparing statement %s: %w", name, err)
		}
	}
	return conn, nil
}
// UpsertRepo records the git repository at repoPath, keyed by its current HEAD
// commit hash together with repoPath, and returns that record's ID.
//
// Git metadata is read from repoPath/.git, and HEAD must resolve. The boolean
// result is true when a row for this (HEAD hash, path) pair already exists, and
// false when this call inserted a new row under a freshly generated UUID. The
// returned ID is always one that is actually stored in the repos table.
func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) {
	gitDir := osfs.New(filepath.Join(repoPath, ".git"))
	storage := filesystem.NewStorage(gitDir, cache.NewObjectLRUDefault())
	// git.Open's second argument is the worktree (checkout root), not .git.
	gitRepo, err := git.Open(storage, osfs.New(repoPath))
	if err != nil {
		return "", false, fmt.Errorf("opening git repository %q: %w", repoPath, err)
	}
	headRef, err := gitRepo.Head()
	if err != nil {
		return "", false, fmt.Errorf("resolving HEAD of %q: %w", repoPath, err)
	}
	headHash := headRef.Hash().String()

	conn, err := db.DB(ctx)
	if err != nil {
		return "", false, err
	}
	defer conn.Release()

	var id string
	err = conn.QueryRow(ctx, "get_repo", headHash, repoPath).Scan(&id)
	if err == nil {
		return id, true, nil
	}
	if !errors.Is(err, pgx.ErrNoRows) {
		return "", false, fmt.Errorf("looking up repo %q: %w", repoPath, err)
	}

	id = uuid.NewString()
	tag, err := conn.Exec(ctx, "insert_repo", id, headHash, repoPath)
	if err != nil {
		return "", false, fmt.Errorf("inserting repo %q: %w", repoPath, err)
	}
	if tag.RowsAffected() == 0 {
		// insert_repo is ON CONFLICT DO NOTHING, so zero affected rows means a
		// conflicting row (e.g. one inserted concurrently) prevented storing
		// id. Return the persisted row's ID rather than the unsaved one.
		var existing string
		if err := conn.QueryRow(ctx, "get_repo", headHash, repoPath).Scan(&existing); err != nil {
			return "", false, fmt.Errorf("repo %q was not inserted and could not be re-read: %w", repoPath, err)
		}
		return existing, true, nil
	}
	return id, false, nil
}
// GetChunk returns the stored document text of chunk number chunkID for the
// file at path in repository repoID. When no matching row exists, the error
// is the one Scan reports for an empty result (pgx.ErrNoRows), which callers
// such as GetChunkContext test for.
func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {
	conn, err := db.DB(ctx)
	if err != nil {
		return "", err
	}
	defer conn.Release()

	// The query compares chunk_id as extracted JSON text, so pass the
	// decimal string form of the number.
	row := conn.QueryRow(ctx, "get_chunk", strconv.Itoa(chunkID), path, repoID)

	var document string
	if err = row.Scan(&document); err != nil {
		return "", err
	}
	return document, nil
}
// GetChunkContext returns the text of the chunks around chunkID for the file
// at path in repository repoID, concatenated in chunk order without
// separators.
//
// The window covers chunk numbers chunkID-distance through chunkID+distance
// inclusive, with the lower bound clamped to 0. Iteration stops early, without
// error, at the first chunk number that has no stored row. Any other lookup
// error is returned. A negative distance yields an empty window and "".
func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int, path, repoID string) (string, error) {
	first := chunkID - distance
	if first < 0 {
		first = 0
	}
	last := chunkID + distance

	var sb strings.Builder
	// Use a distinct loop variable: shadowing chunkID here would make the
	// upper bound track the loop variable and leave the window unbounded.
	for n := first; n <= last; n++ {
		chunk, err := db.GetChunk(ctx, n, path, repoID)
		if errors.Is(err, pgx.ErrNoRows) {
			// No row for this chunk number; treat it as the end of the file.
			break
		}
		if err != nil {
			return "", err
		}
		sb.WriteString(chunk)
	}
	return sb.String(), nil
}
// FromConfig applies any pending embedded migrations to the database at
// cfg.Database.ConnString (a migrate.ErrNoChange result counts as success) and
// then opens a pgx connection pool to the same database.
//
// The migrator is closed once migrations have run, releasing the source and
// database handles it opened, instead of leaving them open for the life of
// the process. Errors from closing it are reported.
func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, error) {
	migFS, err := iofs.New(migrations, "migrations")
	if err != nil {
		return nil, fmt.Errorf("loading embedded migrations: %w", err)
	}
	mig, err := migrate.NewWithSourceInstance("iofs", migFS, cfg.Database.ConnString)
	if err != nil {
		return nil, fmt.Errorf("creating database migrator: %w", err)
	}

	upErr := mig.Up()
	// Close regardless of the Up outcome so the migrator's handles never leak.
	srcErr, dbErr := mig.Close()
	if upErr != nil && !errors.Is(upErr, migrate.ErrNoChange) {
		return nil, fmt.Errorf("problem performing database migrations: %w", upErr)
	}
	if srcErr != nil {
		return nil, fmt.Errorf("closing migration source: %w", srcErr)
	}
	if dbErr != nil {
		return nil, fmt.Errorf("closing migration database connection: %w", dbErr)
	}

	pool, err := pgxpool.New(ctx, cfg.Database.ConnString)
	if err != nil {
		return nil, fmt.Errorf("creating database pool: %w", err)
	}
	return &Database{pool: pool}, nil
}
// FromContext returns the *Database stored in ctx by WrapContext. It returns
// nil, rather than panicking, when ctx carries no *Database; callers that
// cannot proceed without a database should check for nil.
func FromContext(ctx context.Context) *Database {
	// The ok result is not needed: a failed assertion already yields a nil
	// *Database, which is exactly the documented "not present" result.
	db, _ := ctx.Value(contextKeyDB).(*Database)
	return db
}
// WrapContext returns a child of ctx that carries db under this package's
// private context key; retrieve it with FromContext.
func WrapContext(ctx context.Context, db *Database) context.Context {
	return context.WithValue(ctx, contextKeyDB, db)
}