Checks to make sure repo is indexed before generating code. Don't generate tests for changes to tests. Remove unused code. Fix bootstrapping issue with langchaingo tables.
167 lines
4.2 KiB
Go
167 lines
4.2 KiB
Go
package database
|
|
|
|
import (
|
|
"ai-code-assistant/pkg/config"
|
|
"context"
|
|
"embed"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/go-git/go-billy/v5/osfs"
|
|
"github.com/go-git/go-git/v5"
|
|
"github.com/go-git/go-git/v5/plumbing/cache"
|
|
"github.com/go-git/go-git/v5/storage/filesystem"
|
|
"github.com/golang-migrate/migrate/v4"
|
|
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
|
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
_ "github.com/lib/pq"
|
|
"path/filepath"
|
|
"strconv"
|
|
)
|
|
|
|
//go:embed migrations
|
|
var migrations embed.FS
|
|
|
|
// contextKey is the unexported type used for this package's context keys.
// Using a dedicated type keeps the keys distinct from plain strings used as
// context keys elsewhere.
type contextKey string

// contextKeyDB is the key under which WrapContext stores the *Database and
// from which FromContext reads it back.
const contextKeyDB contextKey = "db"
|
|
|
|
// preparedStatements returns a freshly built map from statement name to SQL
// text. (*Database).DB prepares every entry on each connection it acquires,
// which lets the query methods refer to statements by name.
func preparedStatements() map[string]string {
	const (
		getRepoSQL    = `SELECT repo_id FROM repos WHERE repo_hash = $1 AND repo_path = $2`
		insertRepoSQL = `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`
		getChunkSQL   = `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`
		clearRepoSQL  = `DELETE FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$1`
	)

	stmts := make(map[string]string, 4)
	stmts["get_repo"] = getRepoSQL
	stmts["insert_repo"] = insertRepoSQL
	stmts["get_chunk"] = getChunkSQL
	stmts["clear_chunks_for_repo"] = clearRepoSQL

	return stmts
}
|
|
|
|
type Database struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
|
|
conn, err := db.pool.Acquire(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for nm, query := range preparedStatements() {
|
|
if _, err := conn.Conn().Prepare(ctx, nm, query); err != nil {
|
|
return nil, fmt.Errorf("problem preparing statement %s: %w", nm, err)
|
|
}
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) {
|
|
gitPath := osfs.New(filepath.Join(repoPath, ".git"))
|
|
|
|
gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
headRef, err := gitRepo.Head()
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
conn, err := db.DB(ctx)
|
|
if err != nil {
|
|
return "", false, err
|
|
}
|
|
defer conn.Release()
|
|
|
|
var id string
|
|
|
|
if err := conn.QueryRow(ctx, "get_repo", headRef.Hash().String(), repoPath).Scan(&id); err == nil {
|
|
return id, true, nil
|
|
} else if !errors.Is(err, pgx.ErrNoRows) {
|
|
return "", false, err
|
|
}
|
|
|
|
id = uuid.NewString()
|
|
|
|
if _, err := conn.Exec(ctx, "insert_repo", id, headRef.Hash().String(), repoPath); err != nil {
|
|
return "", false, err
|
|
}
|
|
|
|
return id, false, nil
|
|
}
|
|
|
|
func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {
|
|
conn, err := db.DB(ctx)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer conn.Release()
|
|
|
|
var chunk string
|
|
|
|
chunkIDStr := strconv.FormatInt(int64(chunkID), 10)
|
|
|
|
if err := conn.QueryRow(ctx, "get_chunk", chunkIDStr, path, repoID).Scan(&chunk); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return chunk, nil
|
|
}
|
|
|
|
func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int, path, repoID string) (string, error) {
|
|
minChunk := chunkID - distance
|
|
if minChunk < 0 {
|
|
minChunk = 0
|
|
}
|
|
|
|
chunkContext := ""
|
|
|
|
for chunkID := minChunk; chunkID < chunkID+(distance*2); chunkID++ {
|
|
chunk, err := db.GetChunk(ctx, chunkID, path, repoID)
|
|
if err == nil {
|
|
chunkContext += chunk
|
|
} else if !errors.Is(err, pgx.ErrNoRows) {
|
|
return "", err
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
return chunkContext, nil
|
|
}
|
|
|
|
func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, error) {
|
|
migFS, err := iofs.New(migrations, "migrations")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mig, err := migrate.NewWithSourceInstance("iofs", migFS, cfg.Database.ConnString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := mig.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
|
return nil, fmt.Errorf("problem performing database migrations: %w", err)
|
|
}
|
|
|
|
pool, err := pgxpool.New(ctx, cfg.Database.ConnString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Database{pool: pool}, nil
|
|
}
|
|
|
|
func FromContext(ctx context.Context) *Database {
|
|
return ctx.Value(contextKeyDB).(*Database)
|
|
}
|
|
|
|
func WrapContext(ctx context.Context, cfg *Database) context.Context {
|
|
return context.WithValue(ctx, contextKeyDB, cfg)
|
|
}
|