378 lines
13 KiB
C#
378 lines
13 KiB
C#
using System;
|
||
using System.Collections.Generic;
|
||
using System.IO;
|
||
using System.Linq;
|
||
using System.Net.Http;
|
||
using System.Text;
|
||
using System.Threading;
|
||
using System.Threading.Tasks;
|
||
using AIImages.Models;
|
||
using Newtonsoft.Json;
|
||
using Verse;
|
||
|
||
namespace AIImages.Services
|
||
{
|
||
/// <summary>
/// Adapter for the Stable Diffusion HTTP API exposed by the AUTOMATIC1111 WebUI.
/// TODO: consider migrating to the StableDiffusionNet library once its API is fully compatible.
/// </summary>
public class StableDiffusionNetAdapter : IStableDiffusionApiService, IDisposable
{
    // A single shared HttpClient is used to prevent socket exhaustion.
    // See: https://learn.microsoft.com/en-us/dotnet/fundamentals/networking/http/httpclient-guidelines
    private static readonly HttpClient _sharedHttpClient = new HttpClient
    {
        Timeout = TimeSpan.FromMinutes(5),
    };

    // Relative API routes (no leading slash; BuildUrl inserts the separator).
    private const string Txt2ImgPath = "sdapi/v1/txt2img";
    private const string ModelsPath = "sdapi/v1/sd-models";
    private const string SamplersPath = "sdapi/v1/samplers";
    private const string SchedulersPath = "sdapi/v1/schedulers";

    private readonly string _apiEndpoint;
    private readonly string _saveFolderPath;
    private bool _disposed;

    /// <summary>
    /// Creates an adapter bound to the given API endpoint and makes sure the
    /// output folder exists.
    /// </summary>
    /// <param name="apiEndpoint">Base URL of the WebUI API; must not be null or empty.</param>
    /// <param name="savePath">Folder for generated images, relative to the game's save data folder.</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="apiEndpoint"/> is null or empty.</exception>
    public StableDiffusionNetAdapter(string apiEndpoint, string savePath = "AIImages/Generated")
    {
        if (string.IsNullOrEmpty(apiEndpoint))
        {
            throw new ArgumentException(
                "API endpoint cannot be null or empty",
                nameof(apiEndpoint)
            );
        }

        _apiEndpoint = apiEndpoint;

        // Resolve the folder the generated images are written to.
        _saveFolderPath = Path.Combine(GenFilePaths.SaveDataFolderPath, savePath);

        // CreateDirectory is a no-op when the folder already exists,
        // so no separate existence check is required.
        Directory.CreateDirectory(_saveFolderPath);

        Log.Message(
            $"[AI Images] StableDiffusion adapter initialized with endpoint: {apiEndpoint}"
        );
    }

    /// <summary>
    /// Sends a txt2img request, decodes the first returned image and saves it
    /// as a PNG file in the adapter's output folder.
    /// </summary>
    /// <param name="request">Generation parameters; a null request or null prompt yields a failure result.</param>
    /// <param name="cancellationToken">Token used to cancel the HTTP request.</param>
    /// <returns>
    /// A success result carrying the image bytes and file path, or a failure
    /// result describing what went wrong. Errors are reported through the
    /// result rather than thrown (except for use after disposal).
    /// </returns>
    /// <exception cref="ObjectDisposedException">Thrown when the adapter has been disposed.</exception>
    public async Task<GenerationResult> GenerateImageAsync(
        GenerationRequest request,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        if (request == null)
        {
            return GenerationResult.Failure("Request cannot be null");
        }

        // Fail explicitly instead of letting the log statement below throw
        // a NullReferenceException that surfaces as a generic error.
        if (request.Prompt == null)
        {
            return GenerationResult.Failure("Prompt cannot be null");
        }

        try
        {
            Log.Message(
                $"[AI Images] Starting image generation with prompt: {request.Prompt.Substring(0, Math.Min(50, request.Prompt.Length))}..."
            );

            // Build the JSON payload expected by the AUTOMATIC1111 API.
            var apiRequest = new
            {
                prompt = request.Prompt,
                negative_prompt = request.NegativePrompt,
                steps = request.Steps,
                cfg_scale = request.CfgScale,
                width = request.Width,
                height = request.Height,
                sampler_name = request.Sampler,
                scheduler = request.Scheduler,
                seed = request.Seed,
                save_images = false,
                send_images = true,
            };

            string jsonRequest = JsonConvert.SerializeObject(apiRequest);
            string jsonResponse;

            // Dispose both the request content and the response so their
            // buffers and connection resources are released promptly.
            using (var content = new StringContent(jsonRequest, Encoding.UTF8, "application/json"))
            using (
                HttpResponseMessage response = await _sharedHttpClient.PostAsync(
                    BuildUrl(_apiEndpoint, Txt2ImgPath),
                    content,
                    cancellationToken
                )
            )
            {
                if (!response.IsSuccessStatusCode)
                {
                    string errorContent = await response.Content.ReadAsStringAsync();
                    Log.Error(
                        $"[AI Images] API request failed: {response.StatusCode} - {errorContent}"
                    );
                    return GenerationResult.Failure($"API Error: {response.StatusCode}");
                }

                jsonResponse = await response.Content.ReadAsStringAsync();
            }

            var apiResponse = JsonConvert.DeserializeObject<Txt2ImgResponse>(jsonResponse);

            if (apiResponse?.images == null || apiResponse.images.Length == 0)
            {
                return GenerationResult.Failure("No images returned from API");
            }

            // Decode the first image from base64.
            byte[] imageData = Convert.FromBase64String(apiResponse.images[0]);

            // Millisecond resolution keeps two images generated within the same
            // second from overwriting each other; the invariant culture keeps
            // the name independent of the user's calendar settings.
            string timestamp = DateTime.Now.ToString(
                "yyyyMMdd_HHmmss_fff",
                System.Globalization.CultureInfo.InvariantCulture
            );
            string fullPath = Path.Combine(_saveFolderPath, $"pawn_{timestamp}.png");
            await File.WriteAllBytesAsync(fullPath, imageData);

            Log.Message($"[AI Images] Image generated successfully and saved to: {fullPath}");

            return GenerationResult.SuccessResult(imageData, fullPath, request);
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // The caller's token was signalled: this is a genuine cancellation.
            Log.Message("[AI Images] Image generation was cancelled");
            return GenerationResult.Failure("Generation cancelled");
        }
        catch (OperationCanceledException)
        {
            // The caller did not cancel, so the cancellation came from
            // HttpClient.Timeout (reported as TaskCanceledException).
            Log.Warning("[AI Images] Request timeout. Generation took too long.");
            return GenerationResult.Failure("Request timeout. Generation took too long.");
        }
        catch (HttpRequestException ex)
        {
            Log.Error($"[AI Images] HTTP error: {ex.Message}");
            return GenerationResult.Failure($"Connection error: {ex.Message}");
        }
        catch (Exception ex)
        {
            Log.Error($"[AI Images] Unexpected error: {ex.Message}\n{ex.StackTrace}");
            return GenerationResult.Failure($"Error: {ex.Message}");
        }
    }

    /// <summary>
    /// Checks whether the API answers the model-list request with a success status.
    /// </summary>
    /// <param name="apiEndpoint">Base URL of the API to probe.</param>
    /// <param name="cancellationToken">Token used to cancel the HTTP request.</param>
    /// <returns><c>true</c> on a success status code; <c>false</c> on any other status or any exception.</returns>
    /// <exception cref="ObjectDisposedException">Thrown when the adapter has been disposed.</exception>
    public async Task<bool> CheckApiAvailability(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            using (
                HttpResponseMessage response = await _sharedHttpClient.GetAsync(
                    BuildUrl(apiEndpoint, ModelsPath),
                    cancellationToken
                )
            )
            {
                return response.IsSuccessStatusCode;
            }
        }
        catch (Exception ex)
        {
            Log.Warning($"[AI Images] API check failed: {ex.Message}");
            return false;
        }
    }

    /// <summary>
    /// Loads the names of the models reported by the API (title, falling back to model name).
    /// </summary>
    /// <param name="apiEndpoint">Base URL of the API.</param>
    /// <param name="cancellationToken">Token used to cancel the HTTP request.</param>
    /// <returns>The model names, or an empty list when the request fails.</returns>
    /// <exception cref="ObjectDisposedException">Thrown when the adapter has been disposed.</exception>
    public async Task<List<string>> GetAvailableModels(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            List<string> modelNames = await FetchNamesAsync<SdModel>(
                apiEndpoint,
                ModelsPath,
                model => model.title ?? model.model_name,
                cancellationToken
            );

            if (modelNames == null)
            {
                return new List<string>();
            }

            Log.Message($"[AI Images] Found {modelNames.Count} models");
            return modelNames;
        }
        catch (Exception ex)
        {
            Log.Error($"[AI Images] Failed to load models: {ex.Message}");
            return new List<string>();
        }
    }

    /// <summary>
    /// Loads the sampler names reported by the API.
    /// </summary>
    /// <param name="apiEndpoint">Base URL of the API.</param>
    /// <param name="cancellationToken">Token used to cancel the HTTP request.</param>
    /// <returns>The sampler names, or a built-in default list when the request fails.</returns>
    /// <exception cref="ObjectDisposedException">Thrown when the adapter has been disposed.</exception>
    public async Task<List<string>> GetAvailableSamplers(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            List<string> samplerNames = await FetchNamesAsync<SdSampler>(
                apiEndpoint,
                SamplersPath,
                sampler => sampler.name,
                cancellationToken
            );

            if (samplerNames == null)
            {
                return GetDefaultSamplers();
            }

            Log.Message($"[AI Images] Found {samplerNames.Count} samplers");
            return samplerNames;
        }
        catch (Exception ex)
        {
            Log.Warning($"[AI Images] Failed to load samplers: {ex.Message}");
            return GetDefaultSamplers();
        }
    }

    /// <summary>
    /// Loads the scheduler names reported by the API.
    /// </summary>
    /// <param name="apiEndpoint">Base URL of the API.</param>
    /// <param name="cancellationToken">Token used to cancel the HTTP request.</param>
    /// <returns>The scheduler names, or a built-in default list when the request fails.</returns>
    /// <exception cref="ObjectDisposedException">Thrown when the adapter has been disposed.</exception>
    public async Task<List<string>> GetAvailableSchedulers(
        string apiEndpoint,
        CancellationToken cancellationToken = default
    )
    {
        ThrowIfDisposed();

        try
        {
            List<string> schedulerNames = await FetchNamesAsync<SdScheduler>(
                apiEndpoint,
                SchedulersPath,
                scheduler => scheduler.name,
                cancellationToken
            );

            if (schedulerNames == null)
            {
                return GetDefaultSchedulers();
            }

            Log.Message($"[AI Images] Found {schedulerNames.Count} schedulers");
            return schedulerNames;
        }
        catch (Exception ex)
        {
            Log.Warning($"[AI Images] Failed to load schedulers: {ex.Message}");
            return GetDefaultSchedulers();
        }
    }

    /// <summary>
    /// Joins a base endpoint and a relative route with exactly one slash, so
    /// an endpoint configured with a trailing slash does not produce "//".
    /// </summary>
    private static string BuildUrl(string baseEndpoint, string relativePath)
    {
        string trimmedBase = (baseEndpoint ?? string.Empty).TrimEnd('/');
        return $"{trimmedBase}/{relativePath}";
    }

    /// <summary>
    /// GETs a JSON array from the API and projects each element to a name.
    /// Returns null when the response status is not successful so that each
    /// caller can apply its own fallback; exceptions propagate to the caller.
    /// Null elements and null/empty names are skipped.
    /// </summary>
    private static async Task<List<string>> FetchNamesAsync<TItem>(
        string apiEndpoint,
        string relativePath,
        Func<TItem, string> nameSelector,
        CancellationToken cancellationToken
    )
        where TItem : class
    {
        using (
            HttpResponseMessage response = await _sharedHttpClient.GetAsync(
                BuildUrl(apiEndpoint, relativePath),
                cancellationToken
            )
        )
        {
            if (!response.IsSuccessStatusCode)
            {
                return null;
            }

            string jsonResponse = await response.Content.ReadAsStringAsync();
            List<TItem> items = JsonConvert.DeserializeObject<List<TItem>>(jsonResponse);

            if (items == null)
            {
                return new List<string>();
            }

            return items
                .Where(item => item != null)
                .Select(nameSelector)
                .Where(name => !string.IsNullOrEmpty(name))
                .ToList();
        }
    }

    /// <summary>
    /// Built-in sampler names used when the API list cannot be loaded.
    /// </summary>
    private List<string> GetDefaultSamplers()
    {
        return new List<string>
        {
            "Euler a",
            "Euler",
            "LMS",
            "Heun",
            "DPM2",
            "DPM2 a",
            "DPM++ 2S a",
            "DPM++ 2M",
            "DPM++ SDE",
            "DPM fast",
            "DPM adaptive",
            "LMS Karras",
            "DPM2 Karras",
            "DPM2 a Karras",
            "DPM++ 2S a Karras",
            "DPM++ 2M Karras",
            "DPM++ SDE Karras",
            "DDIM",
            "PLMS",
        };
    }

    /// <summary>
    /// Built-in scheduler names used when the API list cannot be loaded.
    /// </summary>
    private List<string> GetDefaultSchedulers()
    {
        return new List<string>
        {
            "Automatic",
            "Uniform",
            "Karras",
            "Exponential",
            "Polyexponential",
            "SGM Uniform",
        };
    }

    /// <summary>
    /// Throws <see cref="ObjectDisposedException"/> once <see cref="Dispose"/> has been called.
    /// </summary>
    private void ThrowIfDisposed()
    {
        if (_disposed)
        {
            throw new ObjectDisposedException(nameof(StableDiffusionNetAdapter));
        }
    }

    /// <summary>
    /// Marks the adapter as disposed. Calling it more than once is harmless.
    /// </summary>
    public void Dispose()
    {
        if (_disposed)
        {
            return;
        }

        // The shared HttpClient is intentionally NOT disposed: it is used globally.
        _disposed = true;
    }

    // Helper types for deserializing JSON responses.
#pragma warning disable S3459, S1144, IDE1006 // Properties set by JSON deserializer
    private sealed class Txt2ImgResponse
    {
        public string[] images { get; set; }
    }

    private sealed class SdModel
    {
        public string title { get; set; }
        public string model_name { get; set; }
    }

    private sealed class SdSampler
    {
        public string name { get; set; }
    }

    private sealed class SdScheduler
    {
        public string name { get; set; }
    }
#pragma warning restore S3459, S1144, IDE1006
}
|
||
}
|