203 lines
5.1 KiB
Go
203 lines
5.1 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/go-sql-driver/mysql" // Import for registration
|
|
|
|
"gitea.zokki.net/zokki/uni/web43-diary/context"
|
|
"gitea.zokki.net/zokki/uni/web43-diary/internal/config"
|
|
)
|
|
|
|
type SQLTable interface {
|
|
LoadForeignValues(*context.Context) error
|
|
}
|
|
|
|
// SQLPrimary may be implemented by table types that want to report their
// primary key. It is not referenced in this part of the file.
type SQLPrimary interface {
	// GetPrimaryKeys returns the primary-key identifiers of the table
	// (presumably column names — confirm against implementations).
	GetPrimaryKeys() []string
}
|
|
|
|
// tablesToCreate collects the table descriptions registered through
// AddCreateTableQueue; CreateTablesFromQueue sorts them and creates each one.
// It starts as an empty, non-nil slice.
var tablesToCreate = make([]*sqlTable, 0)
|
|
|
|
func NewDB() (*sql.DB, error) {
|
|
db, err := sql.Open("mysql", config.Config.DatabaseDsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := db.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
db.SetConnMaxLifetime(time.Minute * 3)
|
|
db.SetConnMaxIdleTime(time.Minute * 3)
|
|
db.SetMaxOpenConns(1)
|
|
db.SetMaxIdleConns(1)
|
|
|
|
return db, nil
|
|
}
|
|
|
|
func AddCreateTableQueue(table any) {
|
|
tablesToCreate = append(tablesToCreate, getSQLTableFromInterface(table))
|
|
}
|
|
|
|
func CreateTablesFromQueue() {
|
|
db, err := NewDB()
|
|
if err != nil {
|
|
log.Fatal("[CreateTable] connect db: ", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
sort.Sort(ByTableName(tablesToCreate))
|
|
|
|
for _, table := range tablesToCreate {
|
|
sqlStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (", table.Name)
|
|
sqlForeignKeys := ""
|
|
|
|
for _, column := range table.Columns {
|
|
sqlStatement += fmt.Sprintf("%s %s, ", column.Name, column.Statement)
|
|
|
|
if column.ForeignKey != "" {
|
|
sqlForeignKeys += fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s, ", column.Name, column.ForeignKey)
|
|
}
|
|
}
|
|
|
|
if _, err := db.Exec(strings.TrimSuffix(sqlStatement+sqlForeignKeys+table.PrimaryKey, ", ") + ")"); err != nil {
|
|
log.Fatal(fmt.Sprintf("[CreateTable] create table `%s`: ", table.Name), err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func GetOne[TableVal SQLTable](ctx *context.Context, val TableVal) (TableVal, error) {
|
|
return GetOneWhere(ctx, QueryBuilderFromInterface(val))
|
|
}
|
|
|
|
func GetOneWhere[TableVal SQLTable](ctx *context.Context, queryBuilder *queryBuilder[TableVal]) (TableVal, error) {
|
|
var empty TableVal
|
|
|
|
sqlQuery, sqlValues := queryBuilder.Limit(1).BuildSelect()
|
|
sqlRows, err := ctx.DB.QueryContext(ctx, sqlQuery, sqlValues...)
|
|
if err != nil {
|
|
return empty, err
|
|
}
|
|
defer sqlRows.Close()
|
|
|
|
if !sqlRows.Next() {
|
|
return empty, fmt.Errorf("no data found for table '%s'", queryBuilder.table.Name)
|
|
}
|
|
|
|
columns, err := sqlRows.Columns()
|
|
if err != nil {
|
|
return empty, err
|
|
}
|
|
|
|
values := make([]any, len(columns))
|
|
columnPointers := make([]any, len(columns))
|
|
for i := range values {
|
|
columnPointers[i] = &values[i]
|
|
}
|
|
|
|
if err := sqlRows.Scan(columnPointers...); err != nil {
|
|
return empty, err
|
|
}
|
|
|
|
val := queryBuilder.object
|
|
deserialize(val, columns, values)
|
|
|
|
sqlRows.Close() // needs to be closed before using db again
|
|
err = val.LoadForeignValues(ctx)
|
|
if err != nil {
|
|
return empty, err
|
|
}
|
|
|
|
return val, nil
|
|
}
|
|
|
|
func GetAll[TableVal SQLTable](ctx *context.Context, val TableVal) ([]TableVal, error) {
|
|
return GetAllWhere(ctx, QueryBuilderFromInterface(val))
|
|
}
|
|
|
|
func GetAllWhere[TableVal SQLTable](ctx *context.Context, queryBuilder *queryBuilder[TableVal]) ([]TableVal, error) {
|
|
sqlQuery, sqlValues := queryBuilder.BuildSelect()
|
|
sqlRows, err := ctx.DB.QueryContext(ctx, sqlQuery, sqlValues...)
|
|
if err != nil {
|
|
log.Println("err", err, sqlQuery, sqlValues)
|
|
return nil, err
|
|
}
|
|
defer sqlRows.Close()
|
|
|
|
columns, err := sqlRows.Columns()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
values := make([]any, len(columns))
|
|
columnPointers := make([]any, len(columns))
|
|
for i := range values {
|
|
columnPointers[i] = &values[i]
|
|
}
|
|
|
|
tableRows := []TableVal{}
|
|
reflectType := reflect.TypeOf(queryBuilder.object).Elem()
|
|
for sqlRows.Next() {
|
|
if err := sqlRows.Scan(columnPointers...); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tableVal := reflect.New(reflectType).Interface().(TableVal)
|
|
deserialize(tableVal, columns, values)
|
|
|
|
tableRows = append(tableRows, tableVal)
|
|
}
|
|
|
|
sqlRows.Close() // needs to be closed before using db again
|
|
for _, row := range tableRows {
|
|
err = row.LoadForeignValues(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return tableRows, nil
|
|
}
|
|
|
|
func InsertInto(ctx *context.Context, table any) (int64, error) {
|
|
sqlTableName := getSQLTableName(table)
|
|
columns, tableValues := serialize(table)
|
|
if len(columns) < 1 {
|
|
return 0, nil
|
|
}
|
|
|
|
res, err := ctx.DB.ExecContext(ctx, fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s ?)", sqlTableName, strings.Join(columns, ", "), strings.Repeat("?, ", len(tableValues)-1)), tableValues...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return res.LastInsertId()
|
|
}
|
|
|
|
func Update[TableVal SQLTable](ctx *context.Context, searchTable TableVal, newTable TableVal) error {
|
|
newTableColumns, newTableValues := serialize(newTable)
|
|
if len(newTableColumns) < 1 {
|
|
return nil
|
|
}
|
|
|
|
queryBuilder := QueryBuilderFromInterface(searchTable)
|
|
where, whereValues := queryBuilder.where.Build(true)
|
|
|
|
queryValues := append(newTableValues, whereValues...)
|
|
_, err := ctx.DB.ExecContext(ctx, fmt.Sprintf("UPDATE `%s` SET %s = ? %s", queryBuilder.table.Name, strings.Join(newTableColumns, " = ?, "), where), queryValues...)
|
|
return err
|
|
}
|
|
|
|
func Delete(ctx *context.Context, table SQLTable) error {
|
|
sqlQuery, sqlValues := QueryBuilderFromInterface(table).BuildDelete()
|
|
_, err := ctx.DB.ExecContext(ctx, sqlQuery, sqlValues...)
|
|
return err
|
|
}
|