235 lines
8.4 KiB
C#
235 lines
8.4 KiB
C#
using System;
|
||
using System.Collections.Generic;
|
||
using System.IO;
|
||
using System.Net.Http;
|
||
using System.Text;
|
||
using System.Threading.Tasks;
|
||
using AIImages.Models;
|
||
using Newtonsoft.Json;
|
||
using Verse;
|
||
|
||
namespace AIImages.Services
{
    /// <summary>
    /// Client for the Stable Diffusion HTTP API exposed by AUTOMATIC1111 WebUI
    /// (<c>/sdapi/v1/*</c> endpoints). Generated images are written as PNG files into a
    /// folder under RimWorld's save-data directory.
    /// </summary>
    public class StableDiffusionApiService : IStableDiffusionApiService
    {
        /// <summary>Prefix of every generated image file name.</summary>
        private const string FileNamePrefix = "pawn_";

        /// <summary>
        /// Upper bound for lightweight metadata calls (availability check, model list,
        /// sampler list). Without it those calls would inherit the 5-minute generation
        /// timeout of <see cref="httpClient"/> and could block the caller for minutes
        /// when the server is unreachable or hung.
        /// </summary>
        private static readonly TimeSpan MetadataRequestTimeout = TimeSpan.FromSeconds(30);

        private readonly HttpClient httpClient;
        private readonly string saveFolderPath;

        /// <summary>
        /// Creates the service and ensures the output folder exists.
        /// </summary>
        /// <param name="savePath">
        /// Output folder, relative to <c>GenFilePaths.SaveDataFolderPath</c>.
        /// </param>
        public StableDiffusionApiService(string savePath = "AIImages/Generated")
        {
            // A single client is reused for all requests. Generation can be slow, so the
            // whole-request timeout is generous; metadata calls use a shorter per-call bound.
            httpClient = new HttpClient { Timeout = TimeSpan.FromMinutes(5) };

            saveFolderPath = Path.Combine(GenFilePaths.SaveDataFolderPath, savePath);

            // CreateDirectory is a no-op when the folder already exists.
            Directory.CreateDirectory(saveFolderPath);
        }

        /// <summary>
        /// Sends a txt2img request and saves the first returned image as a PNG file.
        /// </summary>
        /// <param name="request">Generation parameters. <c>request.Model</c> is used as the API base URL.</param>
        /// <returns>
        /// A success result carrying the image bytes and the saved file path, or a failure
        /// result describing what went wrong. Errors are reported through the result, not thrown.
        /// </returns>
        public async Task<GenerationResult> GenerateImageAsync(GenerationRequest request)
        {
            if (request == null)
            {
                Log.Error("[AI Images] GenerateImageAsync called with a null request");
                return GenerationResult.Failure("Error: generation request is null");
            }

            try
            {
                // A null prompt is treated as empty so that logging and the request body stay valid.
                string prompt = request.Prompt ?? string.Empty;
                Log.Message(
                    $"[AI Images] Starting image generation with prompt: {prompt.Substring(0, Math.Min(50, prompt.Length))}..."
                );

                // Request body for AUTOMATIC1111's txt2img endpoint. save_images=false asks the
                // server not to keep its own copy; send_images=true asks for base64 images back.
                var apiRequest = new
                {
                    prompt = prompt,
                    negative_prompt = request.NegativePrompt,
                    steps = request.Steps,
                    cfg_scale = request.CfgScale,
                    width = request.Width,
                    height = request.Height,
                    sampler_name = request.Sampler,
                    seed = request.Seed,
                    save_images = false,
                    send_images = true,
                };

                // NOTE(review): the base URL is taken from request.Model; presumably that property
                // holds the API endpoint rather than a checkpoint name - confirm against
                // GenerationRequest. No checkpoint override is sent in the request body.
                string endpoint = BuildUrl(request.Model, "/sdapi/v1/txt2img");

                string jsonResponse;
                using (var content = new StringContent(JsonConvert.SerializeObject(apiRequest), Encoding.UTF8, "application/json"))
                using (HttpResponseMessage response = await httpClient.PostAsync(endpoint, content))
                {
                    if (!response.IsSuccessStatusCode)
                    {
                        string errorContent = await response.Content.ReadAsStringAsync();
                        Log.Error(
                            $"[AI Images] API request failed: {response.StatusCode} - {errorContent}"
                        );
                        return GenerationResult.Failure($"API Error: {response.StatusCode}");
                    }

                    jsonResponse = await response.Content.ReadAsStringAsync();
                }

                var apiResponse = JsonConvert.DeserializeObject<Txt2ImgResponse>(jsonResponse);
                if (apiResponse?.images == null || apiResponse.images.Length == 0)
                {
                    return GenerationResult.Failure("No images returned from API");
                }

                // Only the first image is used; it arrives base64-encoded.
                byte[] imageData = Convert.FromBase64String(StripDataUriPrefix(apiResponse.images[0]));

                string fullPath = BuildUniqueImagePath();
                await File.WriteAllBytesAsync(fullPath, imageData);

                Log.Message($"[AI Images] Image generated successfully and saved to: {fullPath}");

                return GenerationResult.SuccessResult(imageData, fullPath, request);
            }
            catch (TaskCanceledException)
            {
                // HttpClient signals its Timeout elapsing with TaskCanceledException.
                Log.Warning("[AI Images] Image generation request timed out");
                return GenerationResult.Failure("Request timeout. Generation took too long.");
            }
            catch (HttpRequestException ex)
            {
                Log.Error($"[AI Images] HTTP error: {ex.Message}");
                return GenerationResult.Failure($"Connection error: {ex.Message}");
            }
            catch (Exception ex)
            {
                Log.Error($"[AI Images] Unexpected error: {ex.Message}\n{ex.StackTrace}");
                return GenerationResult.Failure($"Error: {ex.Message}");
            }
        }

        /// <summary>
        /// Checks whether the API answers on <c>/sdapi/v1/sd-models</c> with a success status.
        /// </summary>
        /// <param name="apiEndpoint">Base URL of the WebUI, with or without a trailing '/'.</param>
        /// <returns><c>true</c> on a 2xx response; <c>false</c> on any other status, error or timeout.</returns>
        public async Task<bool> CheckApiAvailability(string apiEndpoint)
        {
            try
            {
                using (HttpResponseMessage response = await GetWithTimeoutAsync(BuildUrl(apiEndpoint, "/sdapi/v1/sd-models")))
                {
                    return response.IsSuccessStatusCode;
                }
            }
            catch (Exception ex)
            {
                Log.Warning($"[AI Images] API check failed: {ex.Message}");
                return false;
            }
        }

        /// <summary>
        /// Fetches the checkpoint list from <c>/sdapi/v1/sd-models</c>.
        /// </summary>
        /// <param name="apiEndpoint">Base URL of the WebUI, with or without a trailing '/'.</param>
        /// <returns>
        /// Model titles (falling back to <c>model_name</c> when the title is empty). The list is
        /// empty on a non-success status or on any error.
        /// </returns>
        public async Task<List<string>> GetAvailableModels(string apiEndpoint)
        {
            try
            {
                string jsonResponse;
                using (HttpResponseMessage response = await GetWithTimeoutAsync(BuildUrl(apiEndpoint, "/sdapi/v1/sd-models")))
                {
                    if (!response.IsSuccessStatusCode)
                    {
                        return new List<string>();
                    }

                    jsonResponse = await response.Content.ReadAsStringAsync();
                }

                var models = JsonConvert.DeserializeObject<List<SdModel>>(jsonResponse);

                var modelNames = new List<string>();
                if (models != null)
                {
                    foreach (var model in models)
                    {
                        if (model == null)
                        {
                            continue;
                        }

                        // Prefer the display title. Skip entries that carry no usable name so that
                        // callers never receive null/empty entries.
                        string name = string.IsNullOrEmpty(model.title) ? model.model_name : model.title;
                        if (!string.IsNullOrEmpty(name))
                        {
                            modelNames.Add(name);
                        }
                    }
                }

                Log.Message($"[AI Images] Found {modelNames.Count} models");
                return modelNames;
            }
            catch (Exception ex)
            {
                Log.Error($"[AI Images] Failed to load models: {ex.Message}");
                return new List<string>();
            }
        }

        /// <summary>
        /// Fetches sampler names from <c>/sdapi/v1/samplers</c>.
        /// </summary>
        /// <param name="apiEndpoint">Base URL of the WebUI, with or without a trailing '/'.</param>
        /// <returns>
        /// Sampler names reported by the server. Returns a built-in default list on a non-success
        /// status or on any error.
        /// </returns>
        public async Task<List<string>> GetAvailableSamplers(string apiEndpoint)
        {
            try
            {
                string jsonResponse;
                using (HttpResponseMessage response = await GetWithTimeoutAsync(BuildUrl(apiEndpoint, "/sdapi/v1/samplers")))
                {
                    if (!response.IsSuccessStatusCode)
                    {
                        return GetDefaultSamplers();
                    }

                    jsonResponse = await response.Content.ReadAsStringAsync();
                }

                var samplers = JsonConvert.DeserializeObject<List<SdSampler>>(jsonResponse);

                var samplerNames = new List<string>();
                if (samplers != null)
                {
                    foreach (var sampler in samplers)
                    {
                        // Skip malformed entries instead of adding nulls to the list.
                        if (sampler != null && !string.IsNullOrEmpty(sampler.name))
                        {
                            samplerNames.Add(sampler.name);
                        }
                    }
                }

                Log.Message($"[AI Images] Found {samplerNames.Count} samplers");
                return samplerNames;
            }
            catch (Exception ex)
            {
                Log.Warning($"[AI Images] Failed to load samplers: {ex.Message}");
                return GetDefaultSamplers();
            }
        }

        /// <summary>
        /// Fallback sampler names used when the server list cannot be retrieved.
        /// A new list is returned on each call, so callers may modify it freely.
        /// </summary>
        private static List<string> GetDefaultSamplers()
        {
            return new List<string>
            {
                "Euler a",
                "Euler",
                "LMS",
                "Heun",
                "DPM2",
                "DPM2 a",
                "DPM++ 2S a",
                "DPM++ 2M",
                "DPM++ SDE",
                "DPM fast",
                "DPM adaptive",
                "LMS Karras",
                "DPM2 Karras",
                "DPM2 a Karras",
                "DPM++ 2S a Karras",
                "DPM++ 2M Karras",
                "DPM++ SDE Karras",
                "DDIM",
                "PLMS",
            };
        }

        /// <summary>
        /// Issues a GET request that is cancelled after <see cref="MetadataRequestTimeout"/>.
        /// The caller owns, and must dispose, the returned response.
        /// </summary>
        private async Task<HttpResponseMessage> GetWithTimeoutAsync(string url)
        {
            using (var cts = new System.Threading.CancellationTokenSource(MetadataRequestTimeout))
            {
                // The default completion option buffers the body, so the token is not needed
                // after this await completes and the source can be disposed.
                return await httpClient.GetAsync(url, cts.Token);
            }
        }

        /// <summary>
        /// Joins an API base URL and an absolute API path (starting with '/'). Trailing '/'
        /// characters on the base are removed so that no double slash appears before the path.
        /// </summary>
        private static string BuildUrl(string baseUrl, string path)
        {
            return (baseUrl ?? string.Empty).TrimEnd('/') + path;
        }

        /// <summary>
        /// Removes a leading <c>data:...,</c> URI prefix, if present, so the remainder can be
        /// decoded with <see cref="Convert.FromBase64String(string)"/>. Plain base64 input is
        /// returned unchanged.
        /// </summary>
        private static string StripDataUriPrefix(string encoded)
        {
            if (encoded != null && encoded.StartsWith("data:", StringComparison.OrdinalIgnoreCase))
            {
                int comma = encoded.IndexOf(',');
                if (comma >= 0)
                {
                    return encoded.Substring(comma + 1);
                }
            }

            return encoded;
        }

        /// <summary>
        /// Returns a path in <see cref="saveFolderPath"/> named
        /// <c>pawn_yyyyMMdd_HHmmss_fff.png</c> that does not exist yet, appending
        /// <c>_1</c>, <c>_2</c>, ... when needed.
        /// </summary>
        private string BuildUniqueImagePath()
        {
            // Millisecond resolution plus a collision counter keep two images generated within
            // the same second from overwriting each other. Local time is used because the name
            // is shown to the player.
            string baseName = $"{FileNamePrefix}{DateTime.Now:yyyyMMdd_HHmmss_fff}";
            string candidate = Path.Combine(saveFolderPath, baseName + ".png");

            for (int suffix = 1; File.Exists(candidate); suffix++)
            {
                candidate = Path.Combine(saveFolderPath, $"{baseName}_{suffix}.png");
            }

            return candidate;
        }

        // DTOs used to deserialize JSON responses. Property names mirror the JSON keys of the
        // corresponding /sdapi/v1 responses; they are only ever set by the JSON deserializer.
#pragma warning disable S3459, S1144 // Properties set by JSON deserializer
        private sealed class Txt2ImgResponse
        {
            public string[] images { get; set; }
        }

        private sealed class SdModel
        {
            public string title { get; set; }
            public string model_name { get; set; }
        }

        private sealed class SdSampler
        {
            public string name { get; set; }
        }
#pragma warning restore S3459, S1144
    }
}
|