package database import ( "ai-code-assistant/pkg/config" "context" "embed" "errors" "fmt" "github.com/go-git/go-billy/v5/osfs" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing/cache" "github.com/go-git/go-git/v5/storage/filesystem" "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/source/iofs" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" _ "github.com/lib/pq" "github.com/tmc/langchaingo/embeddings" "github.com/tmc/langchaingo/vectorstores" "github.com/tmc/langchaingo/vectorstores/pgvector" "path/filepath" "strconv" ) //go:embed migrations var migrations embed.FS type contextKey string const ( contextKeyDB contextKey = "db" ) func preparedStatements() map[string]string { return map[string]string{ "get_repo": `SELECT repo_id FROM repos WHERE repo_hash = $1 AND repo_path = $2`, "insert_repo": `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`, "get_chunk": `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`, "clear_chunks_for_repo": `DELETE FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$1`, } } // Database is a helper that abstracts access to the database as well as handles migrations. type Database struct { pool *pgxpool.Pool } // db gets a references to the connection pool. func (db *Database) db(ctx context.Context) (*pgxpool.Conn, error) { conn, err := db.pool.Acquire(ctx) if err != nil { return nil, err } for nm, query := range preparedStatements() { if _, err := conn.Conn().Prepare(ctx, nm, query); err != nil { return nil, fmt.Errorf("problem preparing statement %s: %w", nm, err) } } return conn, nil } // UpsertRepo will ensure we have a repository reference for this particular repository path and HEAD reference. func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) { gitPath := osfs.New(filepath.Join(repoPath, ".git")) gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath) if err != nil { return "", false, err } headRef, err := gitRepo.Head() if err != nil { return "", false, err } conn, err := db.db(ctx) if err != nil { return "", false, err } defer conn.Release() var id string if err := conn.QueryRow(ctx, "get_repo", headRef.Hash().String(), repoPath).Scan(&id); err == nil { return id, true, nil } else if !errors.Is(err, pgx.ErrNoRows) { return "", false, err } id = uuid.NewString() if _, err := conn.Exec(ctx, "insert_repo", id, headRef.Hash().String(), repoPath); err != nil { return "", false, err } return id, false, nil } // GetChunk will get a specified chunk from the database. func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) { conn, err := db.db(ctx) if err != nil { return "", err } defer conn.Release() var chunk string chunkIDStr := strconv.FormatInt(int64(chunkID), 10) if err := conn.QueryRow(ctx, "get_chunk", chunkIDStr, path, repoID).Scan(&chunk); err != nil { return "", err } return chunk, nil } // ClearChunkIndex will clear all embeddings for a particular repo ID. func (db *Database) ClearChunkIndex(ctx context.Context, repoID string) error { conn, err := db.db(ctx) if err != nil { return err } defer conn.Release() if _, err := conn.Exec(ctx, "clear_chunks_for_repo", repoID); err != nil { return err } return nil } // GetVectorStore gets a vector store capable of storing and searching embeddings. func (db *Database) GetVectorStore(ctx context.Context, embedder embeddings.Embedder) (vectorstores.VectorStore, func(), error) { conn, err := db.db(ctx) if err != nil { return nil, nil, err } vectorStore, err := pgvector.New(ctx, pgvector.WithConn(conn), pgvector.WithEmbedder(embedder), pgvector.WithCollectionName("file_chunks"), ) if err != nil { return nil, nil, err } return vectorStore, func() { conn.Release() }, nil } // GetChunkContext will get a specified chunk and surrounding context based on some amount of distance. func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int, path, repoID string) (string, error) { minChunk := chunkID - distance if minChunk < 0 { minChunk = 0 } chunkContext := "" for chunkID := minChunk; chunkID < chunkID+(distance*2); chunkID++ { chunk, err := db.GetChunk(ctx, chunkID, path, repoID) if err == nil { chunkContext += chunk } else if !errors.Is(err, pgx.ErrNoRows) { return "", err } else { break } } return chunkContext, nil } // FromConfig generates a new Database from a passed in configuration. Used for bootstrapping. func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, error) { migFS, err := iofs.New(migrations, "migrations") if err != nil { return nil, err } mig, err := migrate.NewWithSourceInstance("iofs", migFS, cfg.Database.ConnString) if err != nil { return nil, err } if err := mig.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { return nil, fmt.Errorf("problem performing database migrations: %w", err) } pool, err := pgxpool.New(ctx, cfg.Database.ConnString) if err != nil { return nil, err } return &Database{pool: pool}, nil } // FromContext will retrieve the Database from a Context where WrapContext was used to embed the database in the context. func FromContext(ctx context.Context) *Database { return ctx.Value(contextKeyDB).(*Database) } // WrapContext embeds a Database inside a Context so it can be passed to functions. func WrapContext(ctx context.Context, cfg *Database) context.Context { return context.WithValue(ctx, contextKeyDB, cfg) }