package autopatch

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"strings"

	"ai-code-assistant/pkg/config"
	"ai-code-assistant/pkg/database"
	"ai-code-assistant/pkg/indexer"
	"ai-code-assistant/pkg/llm"

	"github.com/go-git/go-billy/v5/osfs"
	"github.com/go-git/go-git/v5"
	"github.com/go-git/go-git/v5/plumbing/cache"
	"github.com/go-git/go-git/v5/storage/filesystem"
	"github.com/sergi/go-diff/diffmatchpatch"
	"github.com/urfave/cli/v3"
)

// Command returns the "auto-patch" CLI command. The command requires --repo (path to a git repository) and --task
// (what to change). The optional --execute switch controls whether the commit is actually performed.
func Command() *cli.Command {
	repoFlag := &cli.StringFlag{
		Name:     "repo",
		Usage:    "path to git repository",
		Required: true,
	}
	taskFlag := &cli.StringFlag{
		Name:     "task",
		Usage:    "task to perform, e.g. \"add a http route for a new health check endpoint\"",
		Required: true,
	}
	executeFlag := &cli.BoolFlag{
		Name:  "execute",
		Usage: "actually execute the commit, otherwise just print the code to be committed",
	}

	return &cli.Command{
		Name:   "auto-patch",
		Usage:  "this command accepts a repository and a prompt and will generate a git commit attempting code modifications to satisfy the prompt.",
		Flags:  []cli.Flag{repoFlag, taskFlag, executeFlag},
		Action: (&autoPatch{}).run,
	}
}

// autoPatch carries the state shared by the auto-patch steps.
type autoPatch struct {
	// llm is the client used for every prompt (patch, unit test, commit message); set by run from the context.
	llm *llm.LLM
	// execute mirrors the --execute flag: when false, generated code is printed rather than written and committed.
	execute bool
}

// run is the CLI action for the command. It takes the LLM client from ctx, records --execute, indexes the repository
// given by --repo (so relevant chunks can be looked up later), and then generates the commit for --task.
func (a *autoPatch) run(ctx context.Context, cmd *cli.Command) error {
	a.llm = llm.FromContext(ctx)
	a.execute = cmd.Bool("execute")

	repo := cmd.String("repo")
	chunkSize := config.FromContext(ctx).IndexChunkSize
	idx := indexer.New(ctx, repo, chunkSize, false)
	if err := idx.Index(ctx); err != nil {
		return err
	}

	return a.generateGitCommit(ctx, repo, cmd.String("task"))
}

// generateGitCommit produces a code patch for prompt. Unless the patched file is itself a _test.go file, it also
// produces a unit test. It then passes every affected file to commit, which creates the commit message and commits
// the changes.
func (a *autoPatch) generateGitCommit(ctx context.Context, repoPath, prompt string) error { var affectedFiles []string fileName, newCode, err := a.generateCodePatch(ctx, repoPath, prompt) if err != nil { return err } affectedFiles = append(affectedFiles, fileName) // If we modified a test, we don't need to generate a test. if !strings.HasSuffix(fileName, "_test.go") { testFile, err := a.generateUnitTest(ctx, prompt, fileName, newCode) if err != nil { return err } affectedFiles = append(affectedFiles, testFile) } if err := a.commit(ctx, prompt, repoPath, affectedFiles...); err != nil { return err } return nil } // commit stages the changed files and commits them with a commit message. func (a *autoPatch) commit(ctx context.Context, prompt, repoPath string, files ...string) error { gitPath := osfs.New(filepath.Join(repoPath, ".git")) gitRepo, err := git.Open(filesystem.NewStorage(gitPath, cache.NewObjectLRUDefault()), osfs.New(repoPath)) if err != nil { return err } workTree, err := gitRepo.Worktree() if err != nil { return err } for _, file := range files { // Relative paths. file = strings.TrimPrefix(file, repoPath) file = strings.TrimPrefix(file, "/") if _, err := workTree.Add(file); err != nil { return err } } genPrompt, err := llm.GetPrompt("generate_commitmsg", map[string]any{ "Prompt": prompt, "Files": files, }) if err != nil { return err } rsp, err := a.llm.ChatPrompt(ctx, genPrompt) if err != nil { return err } if a.execute { if _, err := workTree.Commit(rsp, &git.CommitOptions{}); err != nil { return err } slog.Info("committed changes to git repo", "repo", repoPath) } else { fmt.Printf("Commit Message:\n%s\n", rsp) } return nil } // generateCodePatch generates an appropriate code patch given a repository and prompt. 
func (a *autoPatch) generateCodePatch(ctx context.Context, repoPath, prompt string) (string, string, error) { db := database.FromContext(ctx) cfg := config.FromContext(ctx) repoID, _, err := db.UpsertRepo(ctx, repoPath) if err != nil { return "", "", err } relDocs := llm.NewGetRelevantDocs(db, a.llm, repoID, cfg.RelevantDocs) chunks, err := relDocs.GetRelevantFileChunks(ctx, prompt) if err != nil { return "", "", err } else if len(chunks) == 0 { return "", "", errors.New("no relevant chunks found") } chunk := chunks[0] slog.Info("found most relevant file chunk", "file", chunk.Name, "start", chunk.Start, "end", chunk.End, "score", chunk.Score, "id", chunk.ChunkID) chunkContext, err := db.GetChunkContext(ctx, chunks[0].ChunkID, 1, chunks[0].Name, repoID) if err != nil { return "", "", err } genPrompt, err := llm.GetPrompt("generate_patch", map[string]string{ "Prompt": prompt, "Context": chunkContext, }) if err != nil { return "", "", err } codeBlock, err := a.generateCode(ctx, genPrompt) if err != nil { return "", "", err } fileName := chunks[0].Name originalFile, err := os.ReadFile(fileName) if err != nil { return "", "", err } dmp := diffmatchpatch.New() diffs := dmp.DiffMain(string(originalFile), codeBlock, true) diffs = cleanDiffs(diffs) if a.execute { slog.Info("applying generated patch to file", "file", fileName) if err := patchFile(fileName, diffs); err != nil { return "", "", err } } else { fmt.Printf("Code block:\n%s\n", codeBlock) fmt.Printf("File to patch: %s\n", fileName) fmt.Println(dmp.DiffPrettyText(diffs)) } return fileName, codeBlock, err } // generateUnitTest passes the new code and a prompt to the LLM to generate an appropriate unit test. It will add the // new test to the bottom of an existing test file, or generate a new one if no unit test file already exists. func (a *autoPatch) generateUnitTest(ctx context.Context, prompt, fileName, newCode string) (string, error) { // Check to see if a test file for this already exists. 
testFileExists := false testFile := strings.ReplaceAll(fileName, ".go", "_test.go") if _, err := os.Stat(testFile); err == nil { testFileExists = true } genPrompt, err := llm.GetPrompt("generate_unittest", map[string]any{ "Prompt": prompt, "Context": newCode, "TestFileExists": testFileExists, }) if err != nil { return "", err } codeBlock, err := a.generateCode(ctx, genPrompt) if err != nil { return "", err } if a.execute { slog.Info("applying generated unit test to file", "file", testFile) if testFileExists { fp, err := os.OpenFile(testFile, os.O_APPEND|os.O_WRONLY, 0644) if err != nil { return "", err } defer fp.Close() if _, err := fp.WriteString("\n" + codeBlock); err != nil { return "", err } } else { fp, err := os.Open(testFile) if err != nil { return "", err } defer fp.Close() if _, err := fp.WriteString(codeBlock); err != nil { return "", err } } } else { fmt.Printf("Unit Test Code block for %s:\n%s\n", testFile, codeBlock) } return testFile, nil } // generateCode takes a prompt and tries to extract code from it. The prompt should try to get the LLM to structure the // code as a single yaml chunk. func (a *autoPatch) generateCode(ctx context.Context, prompt string) (string, error) { rsp, err := a.llm.CodePrompt(ctx, prompt) if err != nil { return "", err } startIdx := strings.Index(rsp, "```") endIdx := strings.LastIndex(rsp, "```") if startIdx == -1 || endIdx == -1 || startIdx >= endIdx { return "", fmt.Errorf("unable to find code block in response: %s", rsp) } codeBlock := rsp[startIdx+3 : endIdx] if strings.HasPrefix(codeBlock, "go") { codeBlock = codeBlock[2:] } return codeBlock, nil } // patchFile attempts to write the merged diff to the original file so it can be staged and committed. 
func patchFile(fileName string, diffs []diffmatchpatch.Diff) error { var buff bytes.Buffer for _, diff := range diffs { text := diff.Text switch diff.Type { case diffmatchpatch.DiffInsert: _, _ = buff.WriteString(text) case diffmatchpatch.DiffEqual: _, _ = buff.WriteString(text) } } if err := os.WriteFile(fileName, buff.Bytes(), 0644); err != nil { return err } return nil } // cleanDiffs will ignore any deletions at the beginning or end of the context since the LLM may trim these off. func cleanDiffs(diffs []diffmatchpatch.Diff) []diffmatchpatch.Diff { startIdx := 0 endIdx := len(diffs) for idx, diff := range diffs { if diff.Type == diffmatchpatch.DiffDelete { startIdx = idx + 1 } else { break } } for idx := len(diffs) - 1; idx >= 0; idx-- { if diffs[idx].Type == diffmatchpatch.DiffDelete { endIdx = idx - 1 } else { break } } return diffs[startIdx:endIdx] }