Code Cleanup and Quality of Life

Check that the repo is indexed before generating code.
Don't generate tests for changes to tests.
Remove unused code.
Fix bootstrapping issue with langchaingo tables.
2025-04-20 08:31:26 -04:00
parent 4b8b8132fd
commit 25f8cae8cb
9 changed files with 282 additions and 256 deletions

View File

@@ -3,6 +3,7 @@ package autopatch
 import (
     "ai-code-assistant/pkg/config"
     "ai-code-assistant/pkg/database"
+    "ai-code-assistant/pkg/indexer"
    "ai-code-assistant/pkg/llm"
    "bytes"
    "context"
@@ -48,6 +49,13 @@ func (a *agent) run(ctx context.Context, cmd *cli.Command) error {
    llmRef := llm.FromContext(ctx)
    a.llm = llmRef
 
+   // Make sure we're indexed.
+   idx := indexer.New(ctx, cmd.String("repo"), config.FromContext(ctx).IndexChunkSize, false)
+   if err := idx.Index(ctx); err != nil {
+       return err
+   }
+
+   // Attempt to generate the commit.
    err := a.generateGitCommit(ctx, cmd.String("repo"), cmd.String("task"))
    if err != nil {
        return err
@@ -57,20 +65,26 @@ func (a *agent) run(ctx context.Context, cmd *cli.Command) error {
    }
 
 func (a *agent) generateGitCommit(ctx context.Context, repoPath, prompt string) error {
+   var affectedFiles []string
+
    fileName, newCode, err := a.generateCodePatch(ctx, repoPath, prompt)
    if err != nil {
        return err
    }
-   testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
-   if err != nil {
-       return err
+   affectedFiles = append(affectedFiles, fileName)
+
+   // If we modified a test, we don't need to generate a test.
+   if !strings.HasSuffix(fileName, "_test.go") {
+       testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
+       if err != nil {
+           return err
+       }
+       affectedFiles = append(affectedFiles, testFile)
    }
-   // fileName, testFile := "/home/mpowers/Projects/simple-go-server/main.go", "/home/mpowers/Projects/simple-go-server/main_test.go"
-   if err := a.commit(ctx, prompt, repoPath, fileName, testFile); err != nil {
+
+   if err := a.commit(ctx, prompt, repoPath, affectedFiles...); err != nil {
        return err
    }
@@ -126,7 +140,7 @@ func (a *agent) generateCodePatch(ctx context.Context, repoPath, prompt string)
    db := database.FromContext(ctx)
    cfg := config.FromContext(ctx)
 
-   repoID, err := db.RepoIDFromPath(ctx, repoPath)
+   repoID, _, err := db.UpsertRepo(ctx, repoPath)
    if err != nil {
        return "", "", err
    }

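Since autopatch constructs the indexer with force set to false, the Index call is effectively a cheap guard on warm runs: UpsertRepo finds an existing row for the current HEAD and Index returns before any re-embedding happens (see pkg/indexer/indexer.go below).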
View File

@@ -40,7 +40,7 @@ func (c *chunks) run(ctx context.Context, cmd *cli.Command) error {
    db := database.FromContext(ctx)
    llmRef := llm.FromContext(ctx)
 
-   repoID, err := db.RepoIDFromPath(ctx, cmd.String("repo"))
+   repoID, _, err := db.UpsertRepo(ctx, cmd.String("repo"))
    if err != nil {
        return err
    }
@@ -51,10 +51,6 @@ func (c *chunks) run(ctx context.Context, cmd *cli.Command) error {
        return err
    }
 
-   if err := relDocs.RankChunks(ctx, cmd.String("query"), chunks); err != nil {
-       return err
-   }
-
    for _, chunk := range chunks {
        slog.Info("found relevant chunk", "name", chunk.Name, "start", chunk.Start, "end", chunk.End, "score", chunk.Score, "id", chunk.ChunkID)
    }

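With RankChunks gone, the chunks command logs results in the order the vector store returns them, which for a pgvector similarity search should already be best-match first. If a caller ever needs an explicit client-side ordering, a minimal sketch using the existing Score field follows; the fileChunk struct is a hypothetical stand-in for llm.FileChunkID, and its Score type is an assumption.

package main

import (
    "cmp"
    "fmt"
    "slices"
)

// fileChunk is a hypothetical stand-in for llm.FileChunkID.
type fileChunk struct {
    Name  string
    Score float64 // assumed type; the real field is logged above as "score"
}

func main() {
    chunks := []*fileChunk{
        {Name: "b.go", Score: 0.42},
        {Name: "a.go", Score: 0.91},
    }
    // Sort highest-similarity first instead of asking the LLM to compare pairs.
    slices.SortFunc(chunks, func(a, b *fileChunk) int {
        return cmp.Compare(b.Score, a.Score)
    })
    fmt.Println(chunks[0].Name) // prints "a.go"
}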
View File

@@ -1,28 +1,16 @@
 package indexer
 
 import (
-   "ai-code-assistant/pkg/database"
-   "ai-code-assistant/pkg/llm"
+   "ai-code-assistant/pkg/indexer"
    "context"
-   "github.com/go-git/go-billy/v5/osfs"
-   "github.com/go-git/go-git/v5"
-   "github.com/go-git/go-git/v5/plumbing/cache"
-   "github.com/go-git/go-git/v5/storage/filesystem"
-   "github.com/google/uuid"
-   "github.com/tmc/langchaingo/schema"
-   "github.com/tmc/langchaingo/vectorstores/pgvector"
    "github.com/urfave/cli/v3"
-   "log/slog"
-   "os"
-   "path/filepath"
-   "strconv"
 )
 
 func Command() *cli.Command {
    return &cli.Command{
        Name:  "indexer",
        Usage: "this command will index a local git repository to build context for the llm",
-       Action: (&indexer{}).run,
+       Action: run,
        Flags: []cli.Flag{
            &cli.StringFlag{
                Name: "repo",
@@ -34,170 +22,20 @@ func Command() *cli.Command {
                Usage: "number of bytes to chunk files into, should be roughly 4x the number of tokens",
                Value: 512 * 4,
            },
+           &cli.BoolFlag{
+               Name:  "force",
+               Usage: "force re-indexing of the repository",
+           },
        },
    }
 }
 
-type indexer struct {
-   db        *database.Database
-   llm       *llm.LLM
-   repoPath  string
-   repoID    string
-   chunkSize int
-}
-
-func (idx *indexer) run(ctx context.Context, cmd *cli.Command) error {
-   idx.db = database.FromContext(ctx)
-   idx.repoPath = cmd.String("repo")
-   idx.chunkSize = int(cmd.Int("chunk-size"))
-   idx.llm = llm.FromContext(ctx)
-
-   if err := idx.upsertRepo(ctx); err != nil {
-       return err
-   }
-   if err := idx.generateFileChunks(ctx); err != nil {
+func run(ctx context.Context, cmd *cli.Command) error {
+   idx := indexer.New(ctx, cmd.String("repo"), int(cmd.Int("chunk-size")), cmd.Bool("force"))
+   if err := idx.Index(ctx); err != nil {
        return err
    }
    return nil
 }
-
-func (idx *indexer) upsertRepo(ctx context.Context) error {
-   gitPath := osfs.New(filepath.Join(idx.repoPath, ".git"))
-   gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath)
-   if err != nil {
-       return err
-   }
-   headRef, err := gitRepo.Head()
-   if err != nil {
-       return err
-   }
-
-   conn, err := idx.db.DB(ctx)
-   if err != nil {
-       return err
-   }
-   defer conn.Release()
-
-   id := uuid.NewString()
-   if _, err := conn.Exec(ctx, "insert_repo", id, headRef.Hash().String(), idx.repoPath); err != nil {
-       return err
-   }
-   idx.repoID = id
-   return nil
-}
-
-func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
-   pathFiles, err := os.ReadDir(path)
-   if err != nil {
-       return err
-   }
-   for _, file := range pathFiles {
-       filePath := filepath.Join(path, file.Name())
-       if file.IsDir() {
-           if err := crawlFiles(ctx, filePath, cb); err != nil {
-               return err
-           }
-       } else {
-           if err := cb(ctx, filePath); err != nil {
-               return err
-           }
-       }
-   }
-   return nil
-}
-
-func (idx *indexer) generateFileChunks(ctx context.Context) error {
-   conn, err := idx.db.DB(ctx)
-   if err != nil {
-       return err
-   }
-   defer conn.Release()
-
-   vectorStore, err := pgvector.New(ctx,
-       pgvector.WithConn(conn),
-       pgvector.WithEmbedder(idx.llm.Embedder()),
-       pgvector.WithCollectionName("file_chunks"),
-   )
-   if err != nil {
-       return err
-   }
-
-   allowedExtensions := []string{".go"}
-   return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error {
-       chunkID := 0
-       return chunkFile(ctx, filePath, idx.chunkSize, func(chunk []byte, start, end uint64) error {
-           shouldIndex := false
-           for _, ext := range allowedExtensions {
-               if filepath.Ext(filePath) == ext {
-                   shouldIndex = true
-                   break
-               }
-           }
-           if !shouldIndex {
-               return nil
-           }
-
-           slog.Info("indexing file", "chunk_id", chunkID, "chunk_size", len(chunk), "file_name", filePath)
-           docs := []schema.Document{{
-               PageContent: string(chunk),
-               Metadata: map[string]any{
-                   "type":      "file_chunk",
-                   "file_path": filePath,
-                   "chunk_id":  strconv.FormatInt(int64(chunkID), 10),
-                   "start":     strconv.FormatUint(start, 10),
-                   "end":       strconv.FormatUint(end, 10),
-                   "repo_id":   idx.repoID,
-               },
-           }}
-           if _, err := vectorStore.AddDocuments(ctx, docs); err != nil {
-               return err
-           }
-           chunkID++
-           return nil
-       })
-   })
-}
-
-// chunkFile will take a file and return it in chunks that are suitable size to be embedded.
-// This is a very simple algorithm right now, it would be better to use a lexer to identify good parts of the AST to
-// split on. We could also implement a reference graph to find the most relevant files based on the relationships
-// between files.
-func chunkFile(_ context.Context, filePath string, maxBytes int, chunkCb func(chunk []byte, start, end uint64) error) error {
-   fileBytes, err := os.ReadFile(filePath)
-   if err != nil {
-       return err
-   }
-   pos := 0
-   for pos < len(fileBytes) {
-       nextChunkSize := maxBytes
-       if pos+maxBytes > len(fileBytes) {
-           nextChunkSize = len(fileBytes) - pos
-       }
-       if err := chunkCb(fileBytes[pos:pos+nextChunkSize], uint64(pos), uint64(pos+nextChunkSize)); err != nil {
-           return err
-       }
-       pos += maxBytes
-   }
-   return nil
-}

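The new force flag makes re-indexing opt-in from the CLI, e.g. indexer --repo /path/to/repo --force (the exact invocation depends on how the command is mounted on the root app). Without it, an already-indexed repo is skipped, as the new pkg/indexer/indexer.go file below shows.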
View File

@@ -73,6 +73,14 @@ func readConfig(ctx context.Context, cmd *cli.Command) (context.Context, error)
        return nil, fmt.Errorf("problem parsing config: %w", err)
    }
 
+   if cfg.IndexChunkSize == 0 {
+       cfg.IndexChunkSize = 512 * 4
+   }
+   if cfg.RelevantDocs == 0 {
+       cfg.RelevantDocs = 5
+   }
+
    return config.WrapContext(ctx, cfg), nil
 }
@@ -84,7 +92,10 @@ func initLogging(ctx context.Context, _ *cli.Command) (context.Context, error) {
        return nil, err
    }
 
-   slog.SetLogLoggerLevel(lvl)
+   handler := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
+       Level: lvl,
+   }))
+   slog.SetDefault(handler)
    return ctx, nil
 }

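With the JSONHandler swap, log lines move from the default text format to structured JSON on stdout. Output should look roughly like the following (key names per the log/slog documentation, values purely illustrative):

{"time":"2025-04-20T08:31:26-04:00","level":"INFO","msg":"indexing file","chunk_id":0,"file_name":"main.go"}

Note that despite the variable name, slog.New returns a *slog.Logger, which is what slog.SetDefault expects.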
View File

@@ -22,10 +22,11 @@ type Configuration struct {
    Logging struct {
        Level string `yaml:"level"`
    } `yaml:"logging"`
-   Embedding    LLMConfig `yaml:"embedding"`
-   Code         LLMConfig `yaml:"code"`
-   Chat         LLMConfig `yaml:"chat"`
-   RelevantDocs int       `yaml:"relevant_docs"`
+   Embedding      LLMConfig `yaml:"embedding"`
+   Code           LLMConfig `yaml:"code"`
+   Chat           LLMConfig `yaml:"chat"`
+   RelevantDocs   int       `yaml:"relevant_docs"`
+   IndexChunkSize int       `yaml:"index_chunk_size"`
 }
 
 type LLMConfig struct {

View File

@@ -6,12 +6,18 @@ import (
    "embed"
    "errors"
    "fmt"
+   "github.com/go-git/go-billy/v5/osfs"
+   "github.com/go-git/go-git/v5"
+   "github.com/go-git/go-git/v5/plumbing/cache"
+   "github.com/go-git/go-git/v5/storage/filesystem"
    "github.com/golang-migrate/migrate/v4"
    _ "github.com/golang-migrate/migrate/v4/database/postgres"
    "github.com/golang-migrate/migrate/v4/source/iofs"
+   "github.com/google/uuid"
    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgxpool"
    _ "github.com/lib/pq"
+   "path/filepath"
    "strconv"
 )
@@ -26,9 +32,10 @@ const (
 func preparedStatements() map[string]string {
    return map[string]string{
-       "insert_repo":    `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`,
-       "repo_from_path": `SELECT repo_id FROM repos WHERE repo_path = $1`,
+       "get_repo":    `SELECT repo_id FROM repos WHERE repo_hash = $1 AND repo_path = $2`,
+       "insert_repo": `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`,
        "get_chunk": `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`,
+       "clear_chunks_for_repo": `DELETE FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$1`,
    }
 }
@@ -51,20 +58,40 @@ func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
    return conn, nil
 }
 
-func (db *Database) RepoIDFromPath(ctx context.Context, path string) (string, error) {
+func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) {
+   gitPath := osfs.New(filepath.Join(repoPath, ".git"))
+   gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath)
+   if err != nil {
+       return "", false, err
+   }
+   headRef, err := gitRepo.Head()
+   if err != nil {
+       return "", false, err
+   }
+
    conn, err := db.DB(ctx)
    if err != nil {
-       return "", err
+       return "", false, err
    }
    defer conn.Release()
 
-   var repoID string
-   if err := conn.QueryRow(ctx, "repo_from_path", path).Scan(&repoID); err != nil {
-       return "", err
+   var id string
+   if err := conn.QueryRow(ctx, "get_repo", headRef.Hash().String(), repoPath).Scan(&id); err == nil {
+       return id, true, nil
+   } else if !errors.Is(err, pgx.ErrNoRows) {
+       return "", false, err
    }
-   return repoID, nil
+
+   id = uuid.NewString()
+   if _, err := conn.Exec(ctx, "insert_repo", id, headRef.Hash().String(), repoPath); err != nil {
+       return "", false, err
+   }
+   return id, false, nil
 }
 
 func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {

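UpsertRepo is keyed on (repo_hash, repo_path), so a new HEAD commit produces a fresh repo row rather than reusing the old one, and the returned bool reports whether a row already existed for the current HEAD. A minimal sketch of the caller contract, assuming a wired-up ctx, db, and default logger (everything except UpsertRepo is illustrative):

repoID, existed, err := db.UpsertRepo(ctx, "/path/to/repo")
if err != nil {
    return err
}
if existed {
    // Row already present for this HEAD; safe to skip re-indexing
    // unless the user forces it.
    slog.Info("repo already indexed", "repo_id", repoID)
}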
View File

@@ -4,3 +4,19 @@ CREATE TABLE IF NOT EXISTS repos (
    repo_path TEXT NOT NULL,
    UNIQUE(repo_hash, repo_path)
 );
+
+CREATE TABLE langchain_pg_collection (
+   name VARCHAR UNIQUE,
+   cmetadata JSON,
+   uuid UUID PRIMARY KEY
+);
+
+CREATE TABLE langchain_pg_embedding (
+   collection_id UUID REFERENCES langchain_pg_collection ON DELETE CASCADE,
+   embedding VECTOR,
+   document VARCHAR,
+   cmetadata JSON,
+   uuid UUID PRIMARY KEY
+);
+
+CREATE INDEX langchain_pg_embedding_collection_id ON langchain_pg_embedding (collection_id);

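These CREATE TABLE statements mirror the schema that langchaingo's pgvector store would otherwise create lazily on first use. Defining them in the migration fixes the bootstrapping issue called out in the commit message: the get_chunk and clear_chunks_for_repo statements are prepared against langchain_pg_embedding before any document has ever been added. The VECTOR column type comes from the pgvector Postgres extension, so this presumably relies on CREATE EXTENSION vector having been run elsewhere (an earlier migration or the database image).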
pkg/indexer/indexer.go (new file, 182 lines)
View File

@@ -0,0 +1,182 @@
+package indexer
+
+import (
+   "ai-code-assistant/pkg/database"
+   "ai-code-assistant/pkg/llm"
+   "context"
+   "github.com/tmc/langchaingo/schema"
+   "github.com/tmc/langchaingo/vectorstores/pgvector"
+   "log/slog"
+   "os"
+   "path/filepath"
+   "strconv"
+)
+
+type Indexer struct {
+   repoPath          string
+   chunkSize         int
+   force             bool
+   db                *database.Database
+   llm               *llm.LLM
+   allowedExtensions []string
+}
+
+func New(ctx context.Context, path string, chunkSize int, force bool) *Indexer {
+   return &Indexer{
+       repoPath:  path,
+       chunkSize: chunkSize,
+       force:     force,
+       db:        database.FromContext(ctx),
+       llm:       llm.FromContext(ctx),
+       allowedExtensions: []string{
+           ".go",
+       },
+   }
+}
+
+func (idx *Indexer) Index(ctx context.Context) error {
+   repoID, didUpdate, err := idx.db.UpsertRepo(ctx, idx.repoPath)
+   if err != nil {
+       return err
+   }
+   if didUpdate && !idx.force {
+       slog.Info("repo already indexed, skipping")
+       return nil
+   } else if didUpdate && idx.force {
+       slog.Info("repo already indexed, but force flag was set, cleaning and re-indexing")
+       if err := idx.cleanIndexForRepo(ctx, repoID); err != nil {
+           return err
+       }
+   } else if !didUpdate {
+       slog.Info("indexing new repository", "path", idx.repoPath, "repo_id", repoID)
+   }
+
+   if err := idx.generateFileChunks(ctx, repoID); err != nil {
+       return err
+   }
+   return nil
+}
+
+func (idx *Indexer) cleanIndexForRepo(ctx context.Context, repoID string) error {
+   conn, err := idx.db.DB(ctx)
+   if err != nil {
+       return err
+   }
+   defer conn.Release()
+
+   if _, err := conn.Exec(ctx, "clear_chunks_for_repo", repoID); err != nil {
+       return err
+   }
+   return nil
+}
+
+func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error {
+   pathFiles, err := os.ReadDir(path)
+   if err != nil {
+       return err
+   }
+   for _, file := range pathFiles {
+       filePath := filepath.Join(path, file.Name())
+       if file.IsDir() {
+           if err := crawlFiles(ctx, filePath, cb); err != nil {
+               return err
+           }
+       } else {
+           if err := cb(ctx, filePath); err != nil {
+               return err
+           }
+       }
+   }
+   return nil
+}
+
+func (idx *Indexer) generateFileChunks(ctx context.Context, repoID string) error {
+   conn, err := idx.db.DB(ctx)
+   if err != nil {
+       return err
+   }
+   defer conn.Release()
+
+   vectorStore, err := pgvector.New(ctx,
+       pgvector.WithConn(conn),
+       pgvector.WithEmbedder(idx.llm.Embedder()),
+       pgvector.WithCollectionName("file_chunks"),
+   )
+   if err != nil {
+       return err
+   }
+
+   return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error {
+       chunkID := 0
+       return chunkFile(ctx, filePath, idx.chunkSize, func(chunk []byte, start, end uint64) error {
+           shouldIndex := false
+           for _, ext := range idx.allowedExtensions {
+               if filepath.Ext(filePath) == ext {
+                   shouldIndex = true
+                   break
+               }
+           }
+           if !shouldIndex {
+               return nil
+           }
+
+           slog.Info("indexing file", "chunk_id", chunkID, "chunk_size", len(chunk), "file_name", filePath)
+           docs := []schema.Document{{
+               PageContent: string(chunk),
+               Metadata: map[string]any{
+                   "type":      "file_chunk",
+                   "file_path": filePath,
+                   "chunk_id":  strconv.FormatInt(int64(chunkID), 10),
+                   "start":     strconv.FormatUint(start, 10),
+                   "end":       strconv.FormatUint(end, 10),
+                   "repo_id":   repoID,
+               },
+           }}
+           if _, err := vectorStore.AddDocuments(ctx, docs); err != nil {
+               return err
+           }
+           chunkID++
+           return nil
+       })
+   })
+}
+
+// chunkFile will take a file and return it in chunks that are suitable size to be embedded.
+// This is a very simple algorithm right now, it would be better to use a lexer to identify good parts of the AST to
+// split on. We could also implement a reference graph to find the most relevant files based on the relationships
+// between files.
+func chunkFile(_ context.Context, filePath string, maxBytes int, chunkCb func(chunk []byte, start, end uint64) error) error {
+   fileBytes, err := os.ReadFile(filePath)
+   if err != nil {
+       return err
+   }
+   pos := 0
+   for pos < len(fileBytes) {
+       nextChunkSize := maxBytes
+       if pos+maxBytes > len(fileBytes) {
+           nextChunkSize = len(fileBytes) - pos
+       }
+       if err := chunkCb(fileBytes[pos:pos+nextChunkSize], uint64(pos), uint64(pos+nextChunkSize)); err != nil {
+           return err
+       }
+       pos += maxBytes
+   }
+   return nil
+}

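chunkFile walks the file in fixed byte windows and reports each chunk with half-open [start, end) offsets, truncating only the final chunk. A self-contained sketch of the boundary math, assuming a 1000-byte file and the same loop shape:

package main

import "fmt"

func main() {
    fileLen, maxBytes := 1000, 400
    for pos := 0; pos < fileLen; pos += maxBytes {
        end := pos + maxBytes
        if end > fileLen {
            end = fileLen
        }
        // Mirrors chunkFile's callback arguments: chunk [pos, end).
        fmt.Printf("chunk [%d,%d) size=%d\n", pos, end, end-pos)
    }
    // Output:
    // chunk [0,400) size=400
    // chunk [400,800) size=400
    // chunk [800,1000) size=200
}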
View File

@@ -3,17 +3,11 @@ package llm
 import (
    "ai-code-assistant/pkg/database"
    "context"
-   "fmt"
-   "github.com/cenkalti/backoff/v5"
    "github.com/tmc/langchaingo/callbacks"
    "github.com/tmc/langchaingo/schema"
    "github.com/tmc/langchaingo/vectorstores"
    "github.com/tmc/langchaingo/vectorstores/pgvector"
-   "log/slog"
-   "slices"
    "strconv"
-   "strings"
-   "time"
 )
 
 type RelevantDocs struct {
@@ -96,56 +90,3 @@ func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string)
    return chunks, nil
 }
-
-func (rd *RelevantDocs) RankChunks(ctx context.Context, query string, chunks []*FileChunkID) error {
-   var didErr error
-   slices.SortFunc(chunks, func(a, b *FileChunkID) int {
-       if didErr != nil {
-           return 0
-       }
-       retr, err := rd.CompareChunks(ctx, query, a, b)
-       if err != nil {
-           didErr = err
-       }
-       return retr
-   })
-   return didErr
-}
-
-func (rd *RelevantDocs) CompareChunks(ctx context.Context, query string, chunk1, chunk2 *FileChunkID) (int, error) {
-   slog.Info("comparing chunks", "chunk_1_name", chunk1.Name, "chunk_1_id", chunk1.ChunkID, "chunk_2_name", chunk2.Name, "chunk_2_id", chunk2.ChunkID)
-
-   prompt := `Given the following two pieces of code pick the most relevant chunk to the task described below. Reply as a json object in the format {"chunk_id": "<chunk>"}. Only reply in JSON. Do not include any explanation or code.`
-   prompt += "\n\n" + query + "\n\n"
-
-   // Now that we have candidates we need to compare them against each other to find the most appropriate place to
-   // inject them.
-   prompt += "-- chunk_id: chunk_1 --\n"
-   prompt += chunk1.Doc.PageContent
-   prompt += "-- chunk_id: chunk_2 --\n"
-   prompt += chunk2.Doc.PageContent
-
-   op := func() (int, error) {
-       rsp, err := rd.llm.CodePrompt(ctx, prompt)
-       if err != nil {
-           return 0, err
-       }
-       if strings.Contains(rsp, "chunk_1") {
-           return -1, nil
-       } else if strings.Contains(rsp, "chunk_2") {
-           return 1, nil
-       }
-       return 0, fmt.Errorf("compare response didn't contain a chunk id: %s", rsp)
-   }
-   return backoff.Retry(ctx, op, backoff.WithBackOff(backoff.NewConstantBackOff(10*time.Millisecond)), backoff.WithMaxTries(1))
-}