diff --git a/cmd/autopatch/autopatch.go b/cmd/autopatch/autopatch.go
index 752f28a..f8b1cde 100644
--- a/cmd/autopatch/autopatch.go
+++ b/cmd/autopatch/autopatch.go
@@ -3,6 +3,7 @@ package autopatch
 import (
	"ai-code-assistant/pkg/config"
	"ai-code-assistant/pkg/database"
+	"ai-code-assistant/pkg/indexer"
	"ai-code-assistant/pkg/llm"
	"bytes"
	"context"
@@ -48,6 +49,13 @@ func (a *agent) run(ctx context.Context, cmd *cli.Command) error {
	llmRef := llm.FromContext(ctx)
	a.llm = llmRef
 
+	// Make sure we're indexed.
+	idx := indexer.New(ctx, cmd.String("repo"), config.FromContext(ctx).IndexChunkSize, false)
+	if err := idx.Index(ctx); err != nil {
+		return err
+	}
+
+	// Attempt to generate the commit.
	err := a.generateGitCommit(ctx, cmd.String("repo"), cmd.String("task"))
	if err != nil {
		return err
@@ -57,20 +65,26 @@ func (a *agent) run(ctx context.Context, cmd *cli.Command) error {
 }
 
 func (a *agent) generateGitCommit(ctx context.Context, repoPath, prompt string) error {
+	var affectedFiles []string
	fileName, newCode, err := a.generateCodePatch(ctx, repoPath, prompt)
	if err != nil {
		return err
	}
 
-	testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
-	if err != nil {
-		return err
+	affectedFiles = append(affectedFiles, fileName)
+
+	// If we modified a test file, we don't need to generate a separate test.
+	if !strings.HasSuffix(fileName, "_test.go") {
+		testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode)
+		if err != nil {
+			return err
+		}
+
+		affectedFiles = append(affectedFiles, testFile)
	}
 
-	// fileName, testFile := "/home/mpowers/Projects/simple-go-server/main.go", "/home/mpowers/Projects/simple-go-server/main_test.go"
-
-	if err := a.commit(ctx, prompt, repoPath, fileName, testFile); err != nil {
+	if err := a.commit(ctx, prompt, repoPath, affectedFiles...); err != nil {
		return err
	}
 
@@ -126,7 +140,7 @@ func (a *agent) generateCodePatch(ctx context.Context, repoPath, prompt string)
	db := database.FromContext(ctx)
	cfg := config.FromContext(ctx)
 
-	repoID, err := db.RepoIDFromPath(ctx, repoPath)
+	repoID, _, err := db.UpsertRepo(ctx, repoPath)
	if err != nil {
		return "", "", err
	}
diff --git a/cmd/chunks/chunks.go b/cmd/chunks/chunks.go
index 3ee9c58..9131659 100644
--- a/cmd/chunks/chunks.go
+++ b/cmd/chunks/chunks.go
@@ -40,7 +40,7 @@ func (c *chunks) run(ctx context.Context, cmd *cli.Command) error {
	db := database.FromContext(ctx)
	llmRef := llm.FromContext(ctx)
 
-	repoID, err := db.RepoIDFromPath(ctx, cmd.String("repo"))
+	repoID, _, err := db.UpsertRepo(ctx, cmd.String("repo"))
	if err != nil {
		return err
	}
@@ -51,10 +51,6 @@ func (c *chunks) run(ctx context.Context, cmd *cli.Command) error {
		return err
	}
 
-	if err := relDocs.RankChunks(ctx, cmd.String("query"), chunks); err != nil {
-		return err
-	}
-
	for _, chunk := range chunks {
		slog.Info("found relevant chunk", "name", chunk.Name, "start", chunk.Start, "end", chunk.End, "score", chunk.Score, "id", chunk.ChunkID)
	}
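Reviewer note: RepoIDFromPath is replaced by database.UpsertRepo (defined later in this patch), which registers the repo at its current HEAD if needed and returns (repoID, alreadyExisted, error). Both call sites above discard the bool with a blank identifier; a minimal sketch of a caller that uses it (ensureRepo is a hypothetical helper, not part of this patch):

	// ensureRepo illustrates the new UpsertRepo contract: the bool reports
	// whether the repo was already registered at its current HEAD hash.
	func ensureRepo(ctx context.Context, db *database.Database, path string) (string, error) {
		repoID, existed, err := db.UpsertRepo(ctx, path)
		if err != nil {
			return "", err
		}
		if existed {
			slog.Info("repo already registered at current HEAD", "repo_id", repoID)
		}
		return repoID, nil
	}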
"log/slog" - "os" - "path/filepath" - "strconv" ) func Command() *cli.Command { return &cli.Command{ Name: "indexer", Usage: "this command will index a local git repository to build context for the llm", - Action: (&indexer{}).run, + Action: run, Flags: []cli.Flag{ &cli.StringFlag{ Name: "repo", @@ -34,170 +22,20 @@ func Command() *cli.Command { Usage: "number of bytes to chunk files into, should be roughly 4x the number of tokens", Value: 512 * 4, }, + &cli.BoolFlag{ + Name: "force", + Usage: "force re-indexing of the repository", + }, }, } } -type indexer struct { - db *database.Database - llm *llm.LLM - repoPath string - repoID string - chunkSize int -} +func run(ctx context.Context, cmd *cli.Command) error { + idx := indexer.New(ctx, cmd.String("repo"), int(cmd.Int("chunk-size")), cmd.Bool("force")) -func (idx *indexer) run(ctx context.Context, cmd *cli.Command) error { - idx.db = database.FromContext(ctx) - idx.repoPath = cmd.String("repo") - idx.chunkSize = int(cmd.Int("chunk-size")) - idx.llm = llm.FromContext(ctx) - - if err := idx.upsertRepo(ctx); err != nil { - return err - } - - if err := idx.generateFileChunks(ctx); err != nil { + if err := idx.Index(ctx); err != nil { return err } return nil } - -func (idx *indexer) upsertRepo(ctx context.Context) error { - gitPath := osfs.New(filepath.Join(idx.repoPath, ".git")) - - gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath) - if err != nil { - return err - } - - headRef, err := gitRepo.Head() - if err != nil { - return err - } - - conn, err := idx.db.DB(ctx) - if err != nil { - return err - } - defer conn.Release() - - id := uuid.NewString() - - if _, err := conn.Exec(ctx, "insert_repo", id, headRef.Hash().String(), idx.repoPath); err != nil { - return err - } - - idx.repoID = id - - return nil -} - -func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error { - pathFiles, err := os.ReadDir(path) - if err != nil { - return err - } - - for _, file := range pathFiles { - filePath := filepath.Join(path, file.Name()) - - if file.IsDir() { - if err := crawlFiles(ctx, filePath, cb); err != nil { - return err - } - } else { - if err := cb(ctx, filePath); err != nil { - return err - } - } - } - - return nil -} - -func (idx *indexer) generateFileChunks(ctx context.Context) error { - conn, err := idx.db.DB(ctx) - if err != nil { - return err - } - defer conn.Release() - - vectorStore, err := pgvector.New(ctx, - pgvector.WithConn(conn), - pgvector.WithEmbedder(idx.llm.Embedder()), - pgvector.WithCollectionName("file_chunks"), - ) - if err != nil { - return err - } - - allowedExtensions := []string{".go"} - - return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error { - chunkID := 0 - - return chunkFile(ctx, filePath, idx.chunkSize, func(chunk []byte, start, end uint64) error { - shouldIndex := false - for _, ext := range allowedExtensions { - if filepath.Ext(filePath) == ext { - shouldIndex = true - - break - } - } - - if !shouldIndex { - return nil - } - - slog.Info("indexing file", "chunk_id", chunkID, "chunk_size", len(chunk), "file_name", filePath) - - docs := []schema.Document{{ - PageContent: string(chunk), - Metadata: map[string]any{ - "type": "file_chunk", - "file_path": filePath, - "chunk_id": strconv.FormatInt(int64(chunkID), 10), - "start": strconv.FormatUint(start, 10), - "end": strconv.FormatUint(end, 10), - "repo_id": idx.repoID, - }, - }} - - if _, err := vectorStore.AddDocuments(ctx, docs); err != nil { - 
diff --git a/cmd/main.go b/cmd/main.go
index e3aaa06..d0a3474 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -73,6 +73,14 @@ func readConfig(ctx context.Context, cmd *cli.Command) (context.Context, error)
		return nil, fmt.Errorf("problem parsing config: %w", err)
	}
 
+	if cfg.IndexChunkSize == 0 {
+		cfg.IndexChunkSize = 512 * 4
+	}
+
+	if cfg.RelevantDocs == 0 {
+		cfg.RelevantDocs = 5
+	}
+
	return config.WrapContext(ctx, cfg), nil
 }
 
@@ -84,7 +92,10 @@ func initLogging(ctx context.Context, _ *cli.Command) (context.Context, error) {
		return nil, err
	}
 
-	slog.SetLogLoggerLevel(lvl)
+	logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
+		Level: lvl,
+	}))
+	slog.SetDefault(logger)
 
	return ctx, nil
 }
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 4534b49..40f7625 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -22,10 +22,11 @@ type Configuration struct {
	Logging struct {
		Level string `yaml:"level"`
	} `yaml:"logging"`
-	Embedding    LLMConfig `yaml:"embedding"`
-	Code         LLMConfig `yaml:"code"`
-	Chat         LLMConfig `yaml:"chat"`
-	RelevantDocs int       `yaml:"relevant_docs"`
+	Embedding      LLMConfig `yaml:"embedding"`
+	Code           LLMConfig `yaml:"code"`
+	Chat           LLMConfig `yaml:"chat"`
+	RelevantDocs   int       `yaml:"relevant_docs"`
+	IndexChunkSize int       `yaml:"index_chunk_size"`
 }
 
 type LLMConfig struct {
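Reviewer note: the fallbacks in readConfig keep a zero value from disabling indexing (512 * 4 matches the indexer's chunk-size flag default). A config file exercising the new key might look like this, with keys taken from the yaml struct tags and values purely illustrative:

	logging:
	  level: info
	relevant_docs: 5
	index_chunk_size: 2048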
diff --git a/pkg/database/database.go b/pkg/database/database.go
index 99f1054..dd9b5fb 100644
--- a/pkg/database/database.go
+++ b/pkg/database/database.go
@@ -6,12 +6,18 @@
 import (
	"embed"
	"errors"
	"fmt"
+	"github.com/go-git/go-billy/v5/osfs"
+	"github.com/go-git/go-git/v5"
+	"github.com/go-git/go-git/v5/plumbing/cache"
+	"github.com/go-git/go-git/v5/storage/filesystem"
	"github.com/golang-migrate/migrate/v4"
	_ "github.com/golang-migrate/migrate/v4/database/postgres"
	"github.com/golang-migrate/migrate/v4/source/iofs"
+	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	_ "github.com/lib/pq"
+	"path/filepath"
	"strconv"
 )
@@ -26,9 +32,10 @@ const (
 
 func preparedStatements() map[string]string {
	return map[string]string{
-		"insert_repo":    `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`,
-		"repo_from_path": `SELECT repo_id FROM repos WHERE repo_path = $1`,
-		"get_chunk":      `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`,
+		"get_repo":              `SELECT repo_id FROM repos WHERE repo_hash = $1 AND repo_path = $2`,
+		"insert_repo":           `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`,
+		"get_chunk":             `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`,
+		"clear_chunks_for_repo": `DELETE FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$1`,
	}
 }
@@ -51,20 +58,40 @@ func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
	return conn, nil
 }
 
-func (db *Database) RepoIDFromPath(ctx context.Context, path string) (string, error) {
+func (db *Database) UpsertRepo(ctx context.Context, repoPath string) (string, bool, error) {
+	gitPath := osfs.New(filepath.Join(repoPath, ".git"))
+
+	gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), gitPath)
+	if err != nil {
+		return "", false, err
+	}
+
+	headRef, err := gitRepo.Head()
+	if err != nil {
+		return "", false, err
+	}
+
	conn, err := db.DB(ctx)
	if err != nil {
-		return "", err
+		return "", false, err
	}
	defer conn.Release()
 
-	var repoID string
+	var id string
 
-	if err := conn.QueryRow(ctx, "repo_from_path", path).Scan(&repoID); err != nil {
-		return "", err
+	if err := conn.QueryRow(ctx, "get_repo", headRef.Hash().String(), repoPath).Scan(&id); err == nil {
+		return id, true, nil
+	} else if !errors.Is(err, pgx.ErrNoRows) {
+		return "", false, err
	}
 
-	return repoID, nil
+	id = uuid.NewString()
+
+	if _, err := conn.Exec(ctx, "insert_repo", id, headRef.Hash().String(), repoPath); err != nil {
+		return "", false, err
+	}
+
+	return id, false, nil
 }
 
 func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {
diff --git a/pkg/database/migrations/01_schema.up.sql b/pkg/database/migrations/01_schema.up.sql
index 52f2996..56f08bf 100644
--- a/pkg/database/migrations/01_schema.up.sql
+++ b/pkg/database/migrations/01_schema.up.sql
@@ -4,3 +4,19 @@ CREATE TABLE IF NOT EXISTS repos (
    repo_path TEXT NOT NULL,
    UNIQUE(repo_hash, repo_path)
 );
+
+CREATE TABLE langchain_pg_collection (
+    name VARCHAR UNIQUE,
+    cmetadata JSON,
+    uuid UUID PRIMARY KEY
+);
+
+CREATE TABLE langchain_pg_embedding (
+    collection_id UUID REFERENCES langchain_pg_collection ON DELETE CASCADE,
+    embedding VECTOR,
+    document VARCHAR,
+    cmetadata JSON,
+    uuid UUID PRIMARY KEY
+);
+
+CREATE INDEX langchain_pg_embedding_collection_id ON langchain_pg_embedding (collection_id);
\ No newline at end of file
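Reviewer note: the VECTOR column type assumes the pgvector extension already exists in the target database; this migration does not create it. Where provisioning doesn't guarantee that, a preceding statement along these lines would be needed (sketch, not part of this patch):

	CREATE EXTENSION IF NOT EXISTS vector;

Also worth flagging: unlike repos, the two langchain tables are created without IF NOT EXISTS, so this migration will fail against a database where those tables already exist (langchaingo's pgvector store can create them itself).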
indexed, skipping") + + return nil + } else if didUpdate && idx.force { + slog.Info("repo already indexed, but force flag was set, cleaning and re-indexing") + if err := idx.cleanIndexForRepo(ctx, repoID); err != nil { + return err + } + } else if !didUpdate { + slog.Info("indexing new repository", "path", idx.repoPath, "repo_id", repoID) + } + + if err := idx.generateFileChunks(ctx, repoID); err != nil { + return err + } + + return nil +} + +func (idx *Indexer) cleanIndexForRepo(ctx context.Context, repoID string) error { + conn, err := idx.db.DB(ctx) + if err != nil { + return err + } + defer conn.Release() + + if _, err := conn.Exec(ctx, "clear_chunks_for_repo", repoID); err != nil { + return err + } + + return nil +} + +func crawlFiles(ctx context.Context, path string, cb func(ctx context.Context, filePath string) error) error { + pathFiles, err := os.ReadDir(path) + if err != nil { + return err + } + + for _, file := range pathFiles { + filePath := filepath.Join(path, file.Name()) + + if file.IsDir() { + if err := crawlFiles(ctx, filePath, cb); err != nil { + return err + } + } else { + if err := cb(ctx, filePath); err != nil { + return err + } + } + } + + return nil +} + +func (idx *Indexer) generateFileChunks(ctx context.Context, repoID string) error { + conn, err := idx.db.DB(ctx) + if err != nil { + return err + } + defer conn.Release() + + vectorStore, err := pgvector.New(ctx, + pgvector.WithConn(conn), + pgvector.WithEmbedder(idx.llm.Embedder()), + pgvector.WithCollectionName("file_chunks"), + ) + if err != nil { + return err + } + + return crawlFiles(ctx, idx.repoPath, func(ctx context.Context, filePath string) error { + chunkID := 0 + + return chunkFile(ctx, filePath, idx.chunkSize, func(chunk []byte, start, end uint64) error { + shouldIndex := false + for _, ext := range idx.allowedExtensions { + if filepath.Ext(filePath) == ext { + shouldIndex = true + + break + } + } + + if !shouldIndex { + return nil + } + + slog.Info("indexing file", "chunk_id", chunkID, "chunk_size", len(chunk), "file_name", filePath) + + docs := []schema.Document{{ + PageContent: string(chunk), + Metadata: map[string]any{ + "type": "file_chunk", + "file_path": filePath, + "chunk_id": strconv.FormatInt(int64(chunkID), 10), + "start": strconv.FormatUint(start, 10), + "end": strconv.FormatUint(end, 10), + "repo_id": repoID, + }, + }} + + if _, err := vectorStore.AddDocuments(ctx, docs); err != nil { + return err + } + + chunkID++ + return nil + }) + }) +} + +// chunkFile will take a file and return it in chunks that are suitable size to be embedded. +// This is a very simple algorithm right now, it would be better to use a lexer to identify good parts of the AST to +// split on. We could also implement a reference graph to find the most relevant files based on the relationships +// between files. 
diff --git a/pkg/llm/relevent_docs.go b/pkg/llm/relevent_docs.go
index db4ebcc..7c66fa8 100644
--- a/pkg/llm/relevent_docs.go
+++ b/pkg/llm/relevent_docs.go
@@ -3,17 +3,11 @@ package llm
 import (
	"ai-code-assistant/pkg/database"
	"context"
-	"fmt"
-	"github.com/cenkalti/backoff/v5"
	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/schema"
	"github.com/tmc/langchaingo/vectorstores"
	"github.com/tmc/langchaingo/vectorstores/pgvector"
-	"log/slog"
-	"slices"
	"strconv"
-	"strings"
-	"time"
 )
 
 type RelevantDocs struct {
@@ -96,56 +90,3 @@
 
	return chunks, nil
 }
-
-func (rd *RelevantDocs) RankChunks(ctx context.Context, query string, chunks []*FileChunkID) error {
-	var didErr error
-
-	slices.SortFunc(chunks, func(a, b *FileChunkID) int {
-		if didErr != nil {
-			return 0
-		}
-
-		retr, err := rd.CompareChunks(ctx, query, a, b)
-		if err != nil {
-			didErr = err
-		}
-
-		return retr
-	})
-
-	return didErr
-}
-
-func (rd *RelevantDocs) CompareChunks(ctx context.Context, query string, chunk1, chunk2 *FileChunkID) (int, error) {
-	slog.Info("comparing chunks", "chunk_1_name", chunk1.Name, "chunk_1_id", chunk1.ChunkID, "chunk_2_name", chunk2.Name, "chunk_2_id", chunk2.ChunkID)
-
-	prompt := `Given the following two pieces of code pick the most relevant chunk to the task described below. Reply as a json object in the format {"chunk_id": ""}. Only reply in JSON. Do not include any explanation or code.`
-
-	prompt += "\n\n" + query + "\n\n"
-
-	// Now that we have candidates we need to compare them against each other to find the most appropriate place to
-	// inject them.
-
-	prompt += "-- chunk_id: chunk_1 --\n"
-	prompt += chunk1.Doc.PageContent
-
-	prompt += "-- chunk_id: chunk_2 --\n"
-	prompt += chunk2.Doc.PageContent
-
-	op := func() (int, error) {
-		rsp, err := rd.llm.CodePrompt(ctx, prompt)
-		if err != nil {
-			return 0, err
-		}
-
-		if strings.Contains(rsp, "chunk_1") {
-			return -1, nil
-		} else if strings.Contains(rsp, "chunk_2") {
-			return 1, nil
-		}
-
-		return 0, fmt.Errorf("compare response didn't contain a chunk id: %s", rsp)
-	}
-
-	return backoff.Retry(ctx, op, backoff.WithBackOff(backoff.NewConstantBackOff(10*time.Millisecond)), backoff.WithMaxTries(1))
-}
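Reviewer note: dropping RankChunks also removes a correctness wart — slices.SortFunc requires a consistent, transitive ordering, which an LLM comparator can't guarantee — along with one model round-trip per comparison. Callers that still want ordered output can sort locally on the similarity score GetRelevantFileChunks already populates. A sketch, assuming FileChunkID.Score is a numeric field as the chunks command's logging suggests (not part of this patch):

	// Highest-scoring chunks first.
	slices.SortFunc(chunks, func(a, b *llm.FileChunkID) int {
		return cmp.Compare(b.Score, a.Score)
	})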