using ICSharpCode.SharpZipLib.Tar;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Security.Authentication;
using System.Threading.Tasks;
using System.Web;
using static Tensorflow.Binding;
namespace Tensorflow.Hub
{
internal static class resolver
{
/// <summary>
/// Model load formats understood by the resolver. Every member carries a
/// <see cref="DescriptionAttribute"/> whose text is identical to the member name.
/// </summary>
public enum ModelLoadFormat
{
    [Description("COMPRESSED")] COMPRESSED = 0,

    [Description("UNCOMPRESSED")] UNCOMPRESSED = 1,

    [Description("AUTO")] AUTO = 2,
}
/// <summary>
/// Tracks the number of bytes received for one download, reports progress
/// (to the console or to the TF logger) and drives extraction of the
/// downloaded archive into a destination directory.
/// </summary>
public class DownloadManager
{
    // Environment variable that switches on interactive console progress.
    private const string DownloadProgressEnvVar = "TFHUB_DOWNLOAD_PROGRESS";
    // Underscore-prefixed spelling, still accepted for backward compatibility.
    private const string LegacyDownloadProgressEnvVar = "_TFHUB_DOWNLOAD_PROGRESS";
    // Minimum number of seconds between two progress messages when the
    // interactive mode is off.
    private const double ProgressLogIntervalSec = 15;

    private readonly string _url;
    // Time (whole seconds, UTC based) at which the last progress message was printed.
    private double _last_progress_msg_print_time;
    private long _total_bytes_downloaded;
    // Length of the longest progress line printed so far; shorter lines are
    // padded to this width so they fully overwrite the previous one.
    private int _max_prog_str;

    /// <summary>Creates a manager reporting progress for <paramref name="url"/>.</summary>
    /// <param name="url">URL shown in the progress messages.</param>
    public DownloadManager(string url)
    {
        _url = url;
        _last_progress_msg_print_time = _now_seconds();
        _total_bytes_downloaded = 0;
        _max_prog_str = 0;
    }

    // Current time in whole seconds. UTC is used so that daylight-saving
    // shifts of the local clock cannot distort the interval between messages.
    private static long _now_seconds()
    {
        return DateTime.UtcNow.Ticks / TimeSpan.TicksPerSecond;
    }

    // True when the download-progress environment variable is set to a
    // non-empty value.
    private bool _interactive_mode()
    {
        return !string.IsNullOrEmpty(Environment.GetEnvironmentVariable(DownloadProgressEnvVar))
            || !string.IsNullOrEmpty(Environment.GetEnvironmentVariable(LegacyDownloadProgressEnvVar));
    }

    // Prints a progress message: on the console (overwriting the previous
    // message) in interactive mode, otherwise through the TF logger.
    private void _print_download_progress_msg(string msg, bool flush = false)
    {
        if (_interactive_mode())
        {
            // Print progress message to console overwriting previous progress
            // message.
            _max_prog_str = Math.Max(_max_prog_str, msg.Length);
            Console.Write($"\r{msg.PadRight(_max_prog_str)}");
            Console.Out.Flush();
            // If flush is true, emit a newline so that later output does not
            // interfere with the interactive progress line.
            if (flush)
                Console.WriteLine();
        }
        else
        {
            // Interactive progress tracking is disabled. Print progress to the
            // standard TF log.
            tf.Logger.Information(msg);
        }
    }

    // Adds bytes_downloaded to the running total and prints a progress
    // message when interactive mode is on or the log interval has elapsed.
    private void _log_progress(long bytes_downloaded)
    {
        _total_bytes_downloaded += bytes_downloaded;
        var now = _now_seconds();
        if (_interactive_mode() || now - _last_progress_msg_print_time > ProgressLogIntervalSec)
        {
            // Print progress message every 15 secs or if interactive progress
            // tracking is enabled.
            _print_download_progress_msg($"Downloading {_url}:" +
                $"{tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true)}");
            _last_progress_msg_print_time = now;
        }
    }

    /// <summary>
    /// Streams the content of <paramref name="fileobj"/> as a tar archive and
    /// stores the extracted result in <paramref name="dst_path"/>, reporting
    /// progress along the way.
    /// </summary>
    /// <param name="fileobj">Stream holding the archive.</param>
    /// <param name="dst_path">Directory receiving the extracted content.</param>
    /// <exception cref="IOException">
    /// The extraction raised a <see cref="TarException"/>; the original
    /// exception is attached as the inner exception.
    /// </exception>
    public void download_and_uncompress(Stream fileobj, string dst_path)
    {
        try
        {
            file_utils.extract_tarfile_to_destination(fileobj, dst_path, _log_progress);
            var total_size_str = tf_utils.bytes_to_readable_str(_total_bytes_downloaded, true);
            _print_download_progress_msg($"Downloaded {_url}, Total size: {total_size_str}", flush: true);
        }
        catch (TarException ex)
        {
            throw new IOException($"{_url} does not appear to be a valid module. Inner message:{ex.Message}", ex);
        }
    }
}
// Flag values registered through set_new_flag, keyed by flag name.
// get_env_setting falls back to these when the matching environment
// variable is unset or empty.
private static readonly Dictionary<string, string> _flags = new();
// Names of the environment variables recognised by the resolver.
private static readonly string _TFHUB_CACHE_DIR = "TFHUB_CACHE_DIR";
private static readonly string _TFHUB_DOWNLOAD_PROGRESS = "TFHUB_DOWNLOAD_PROGRESS";
private static readonly string _TFHUB_MODEL_LOAD_FORMAT = "TFHUB_MODEL_LOAD_FORMAT";
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION = "TFHUB_DISABLE_CERT_VALIDATION";
// Value the TFHUB_DISABLE_CERT_VALIDATION variable is expected to carry.
private static readonly string _TFHUB_DISABLE_CERT_VALIDATION_VALUE = "true";
// Type initializer: registers the two built-in flags together with their
// default values ("AUTO" for the load format, no cache directory).
static resolver()
{
    set_new_flag(name: "tfhub_model_load_format", value: "AUTO");
    set_new_flag(name: "tfhub_cache_dir", value: null);
}
/// <summary>
/// Returns the configured model load format: the value of the
/// TFHUB_MODEL_LOAD_FORMAT environment variable when it is set and non-empty,
/// otherwise the value of the "tfhub_model_load_format" flag.
/// </summary>
public static string model_load_format()
    => get_env_setting(_TFHUB_MODEL_LOAD_FORMAT, "tfhub_model_load_format");
/// <summary>
/// Looks a setting up in the process environment first and in the registered
/// flags second.
/// </summary>
/// <param name="env_var">Name of the environment variable to read.</param>
/// <param name="flag_name">Flag consulted when the variable is unset or empty.</param>
/// <returns>
/// The environment value when it is non-empty; otherwise the flag's value
/// (which may itself be null); null when the flag was never registered.
/// </returns>
public static string? get_env_setting(string env_var, string flag_name)
{
    var fromEnvironment = Environment.GetEnvironmentVariable(env_var);
    if (!string.IsNullOrEmpty(fromEnvironment))
    {
        return fromEnvironment;
    }

    return _flags.TryGetValue(flag_name, out var fromFlag) ? fromFlag : null;
}
/// <summary>
/// Determines the directory used to cache modules.
/// </summary>
/// <param name="default_cache_dir">
/// Directory used when neither the TFHUB_CACHE_DIR environment variable nor
/// the "tfhub_cache_dir" flag yields a value.
/// </param>
/// <param name="use_temp">
/// When true and no directory has been determined, fall back to a
/// "tfhub_modules" folder below the system temp path.
/// </param>
/// <returns>The cache directory, or a null/blank value when none is configured.</returns>
public static string tfhub_cache_dir(string default_cache_dir = null, bool use_temp = false)
{
    var chosen = get_env_setting(_TFHUB_CACHE_DIR, "tfhub_cache_dir") ?? default_cache_dir;

    if (string.IsNullOrWhiteSpace(chosen))
    {
        if (!use_temp)
        {
            // Nothing configured and no temp fallback requested.
            return chosen;
        }

        // Place all TF-Hub modules under <system temp path>/tfhub_modules.
        chosen = Path.Combine(Path.GetTempPath(), "tfhub_modules");
    }

    Console.WriteLine("Using {0} to cache modules.", chosen);
    return chosen;
}
/// <summary>
/// Ensures <paramref name="cache_dir"/> exists and returns the path of the
/// module directory inside it. The module directory itself is not created.
/// </summary>
/// <param name="cache_dir">Cache root; created when missing.</param>
/// <param name="module_name">Name of the module's directory below the root.</param>
/// <returns><paramref name="cache_dir"/> combined with <paramref name="module_name"/>.</returns>
public static string create_local_module_dir(string cache_dir, string module_name)
{
    _ = Directory.CreateDirectory(cache_dir);
    var moduleDir = Path.Combine(cache_dir, module_name);
    return moduleDir;
}
/// <summary>
/// Registers (or overwrites) a flag value. A warning is logged when the flag
/// name does not correspond to one of the known settings.
/// </summary>
/// <param name="name">
/// Flag name. It is matched case-insensitively against the known environment
/// variable names, so the lower-case flag names used by this class
/// (e.g. "tfhub_cache_dir") are recognised.
/// </param>
/// <param name="value">Value to store; may be null.</param>
public static void set_new_flag(string name, string value)
{
    string[] knownNames =
    {
        _TFHUB_CACHE_DIR,
        _TFHUB_DISABLE_CERT_VALIDATION,
        _TFHUB_DOWNLOAD_PROGRESS,
        _TFHUB_MODEL_LOAD_FORMAT,
    };

    // Flags are registered in lower case while the constants above are upper
    // case, hence the case-insensitive comparison.
    if (!knownNames.Contains(name, StringComparer.OrdinalIgnoreCase))
    {
        tf.Logger.Warning($"You are setting a flag '{name}' that cannot be recognized. The flag you set " +
            "may not affect anything in tensorflow.hub.");
    }

    _flags[name] = value;
}
/// <summary>
/// Forwards to <c>file_utils.merge_relative_path</c> with the same arguments
/// and returns its result.
/// </summary>
public static string _merge_relative_path(string dstPath, string relPath)
    => file_utils.merge_relative_path(dstPath, relPath);
/// <summary>
/// Returns the path of the descriptor file that accompanies a module
/// directory: the directory path with ".descriptor.txt" appended.
/// </summary>
public static string _module_descriptor_file(string moduleDir)
    => string.Concat(moduleDir, ".descriptor.txt");
/// <summary>
/// Writes the descriptor file next to <paramref name="moduleDir"/>. The file
/// records the module handle, the current local time, the machine name and
/// the id of the current process; an existing file is overwritten.
/// </summary>
/// <param name="handle">Handle of the module that was downloaded.</param>
/// <param name="moduleDir">Module directory the descriptor belongs to.</param>
public static void _write_module_descriptor_file(string handle, string moduleDir)
{
    var readme = _module_descriptor_file(moduleDir);

    // Process is IDisposable; release its handle once the id has been read.
    int processId;
    using (var currentProcess = Process.GetCurrentProcess())
    {
        processId = currentProcess.Id;
    }

    var content = $"Module: {handle}\nDownload Time: {DateTime.Now}\nDownloader Hostname: {Environment.MachineName} (PID:{processId})";
    tf_utils.atomic_write_string_to_file(readme, content, overwrite: true);
}
/// <summary>
/// Builds the text stored in a lock file:
/// "&lt;machine name&gt;.&lt;process id&gt;.&lt;task uid&gt;".
/// </summary>
/// <param name="task_uid">Unique id of the task that takes the lock.</param>
public static string _lock_file_contents(string task_uid)
{
    // Process is IDisposable; release its handle once the id has been read.
    using (var currentProcess = Process.GetCurrentProcess())
    {
        return $"{Environment.MachineName}.{currentProcess.Id}.{task_uid}";
    }
}
/// <summary>
/// Returns the lock file path for a module directory: the result of
/// <c>tf_utils.absolute_path</c> for the directory with ".lock" appended.
/// </summary>
public static string _lock_filename(string moduleDir)
    => $"{tf_utils.absolute_path(moduleDir)}.lock";
/// <summary>
/// Recovers the module directory a lock file belongs to by removing the
/// ".lock" suffix, i.e. the inverse of the suffix that <c>_lock_filename</c>
/// appends to the module directory path.
/// </summary>
/// <param name="lockFilename">Path of the lock file.</param>
/// <returns>The lock file path without its trailing ".lock".</returns>
/// <exception cref="ArgumentNullException"><paramref name="lockFilename"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="lockFilename"/> does not end with ".lock".</exception>
private static string _module_dir(string lockFilename)
{
    const string lockSuffix = ".lock";

    if (lockFilename is null)
    {
        throw new ArgumentNullException(nameof(lockFilename));
    }
    if (!lockFilename.EndsWith(lockSuffix, StringComparison.Ordinal))
    {
        throw new ArgumentException(
            $"Lock file name {lockFilename} does not end with {lockSuffix}.", nameof(lockFilename));
    }

    return lockFilename.Substring(0, lockFilename.Length - lockSuffix.Length);
}
/// <summary>
/// Reads a lock file and returns the text after its last '.', which is where
/// <c>_lock_file_contents</c> places the task UID. The whole content is
/// returned when it contains no '.'.
/// </summary>
private static string _task_uid_from_lock_file(string lockFilename)
{
    var segments = File.ReadAllText(lockFilename).Split('.');
    return segments[segments.Length - 1];
}
/// <summary>
/// Returns the name of the temporary directory a task downloads a module to:
/// "&lt;full path of moduleDir&gt;.&lt;taskUid&gt;.tmp".
/// </summary>
private static string _temp_download_dir(string moduleDir, string taskUid)
{
    var absoluteModuleDir = Path.GetFullPath(moduleDir);
    return string.Concat(absoluteModuleDir, ".", taskUid, ".tmp");
}
/// <summary>
/// Returns the total size in bytes of all files below
/// <paramref name="directory"/>, descending into subdirectories.
/// </summary>
/// <param name="directory">Directory to measure.</param>
/// <exception cref="DirectoryNotFoundException">The directory does not exist.</exception>
private static long _dir_size(string directory)
{
    long size = 0;
    foreach (var entry in new DirectoryInfo(directory).EnumerateFileSystemInfos())
    {
        // FileInfo.Length throws for directories, so only files contribute a
        // length of their own; directories contribute the size of their content.
        if (entry is FileInfo file)
        {
            size += file.Length;
        }
        else if (entry is DirectoryInfo subDirectory)
        {
            size += _dir_size(subDirectory.FullName);
        }
    }
    return size;
}
/// <summary>
/// Returns the size of the temporary download directory that belongs to the
/// task named in the given lock file, or 0 when that directory is missing.
/// </summary>
/// <param name="lockFilename">Path of the lock file to inspect.</param>
public static long _locked_tmp_dir_size(string lockFilename)
{
    // Reading the lock file happens outside the try block: a missing lock
    // file is reported to the caller rather than treated as size 0.
    var ownerUid = _task_uid_from_lock_file(lockFilename);
    try
    {
        var moduleDir = _module_dir(lockFilename);
        var ownerTmpDir = _temp_download_dir(moduleDir, ownerUid);
        return _dir_size(ownerTmpDir);
    }
    catch (DirectoryNotFoundException)
    {
        return 0;
    }
}
/// <summary>
/// Blocks until the lock file is gone. The lock file is polled every five
/// seconds; once per <paramref name="lockFileTimeoutSec"/> the size of the
/// lock owner's temp directory and the lock file content are sampled, and
/// when neither changed since the previous sample the lock is considered
/// abandoned and deleted.
/// </summary>
/// <param name="handle">Module handle, used in the console messages.</param>
/// <param name="lockFile">Path of the lock file to wait for.</param>
/// <param name="lockFileTimeoutSec">Seconds between two inactivity samples.</param>
private static void _wait_for_lock_to_disappear(string handle, string lockFile, double lockFileTimeoutSec)
{
    long? lockedTmpDirSize = null;
    // A monotonic stopwatch is used instead of the local wall clock so that
    // clock adjustments (e.g. daylight-saving shifts) cannot shorten or
    // stretch the inactivity interval.
    var sinceLastSample = Stopwatch.StartNew();
    var lockFileContent = "";
    while (File.Exists(lockFile))
    {
        try
        {
            Console.WriteLine($"Module '{handle}' already being downloaded by '{File.ReadAllText(lockFile)}'. Waiting.");
            if (sinceLastSample.Elapsed.TotalSeconds > lockFileTimeoutSec)
            {
                var curLockedTmpDirSize = _locked_tmp_dir_size(lockFile);
                var curLockFileContent = File.ReadAllText(lockFile);
                if (curLockedTmpDirSize == lockedTmpDirSize && curLockFileContent == lockFileContent)
                {
                    // Same owner and no growth since the previous sample.
                    Console.WriteLine($"Deleting lock file {lockFile} due to inactivity.");
                    File.Delete(lockFile);
                    break;
                }
                lockedTmpDirSize = curLockedTmpDirSize;
                sinceLastSample.Restart();
                lockFileContent = curLockFileContent;
            }
        }
        catch (FileNotFoundException)
        {
            // Lock file or temp directory were deleted during check. Continue
            // to check whether download succeeded or we need to start our own
            // download.
        }
        System.Threading.Thread.Sleep(5000);
    }
}
/// <summary>
/// Downloads a module into <paramref name="moduleDir"/> while guarding the
/// destination with a lock file, so that only one task performs the download
/// and the module directory only appears once the download is complete.
/// </summary>
/// <param name="handle">Module handle; passed to <paramref name="downloadFn"/> and used in messages.</param>
/// <param name="downloadFn">
/// Callback invoked as <c>downloadFn(handle, tmpDir)</c>; it must place the
/// module content into the temporary directory it is given.
/// </param>
/// <param name="moduleDir">Final directory of the module.</param>
/// <param name="lock_file_timeout_sec">
/// Seconds of inactivity after which another task's lock is considered stale.
/// </param>
/// <returns><paramref name="moduleDir"/>.</returns>
/// <remarks>
/// The temporary directory and the lock file owned by this call are removed
/// on every exit path, including early returns and exceptions thrown by
/// <paramref name="downloadFn"/>.
/// </remarks>
public static async Task<string> atomic_download_async(
    string handle,
    Func<string, string, Task> downloadFn,
    string moduleDir,
    int lock_file_timeout_sec = 10 * 60)
{
    var lockFile = _lock_filename(moduleDir);
    var taskUid = Guid.NewGuid().ToString("N");
    var lockContents = _lock_file_contents(taskUid);
    var tmpDir = _temp_download_dir(moduleDir, taskUid);

    // Function to check whether model has already been downloaded.
    Func<bool> checkModuleExists = () =>
        Directory.Exists(moduleDir) &&
        Directory.EnumerateFileSystemEntries(moduleDir).Any();

    // Check whether the model has already been downloaded before locking
    // the destination path.
    if (checkModuleExists())
    {
        return moduleDir;
    }

    // Everything from acquiring the lock onwards runs inside try/finally so
    // that the lock file and the temp directory are removed even when the
    // download fails or the method returns early.
    try
    {
        while (true)
        {
            try
            {
                tf_utils.atomic_write_string_to_file(lockFile, lockContents, false);
                // Must test condition again, since another process could have created
                // the module and deleted the old lock file since last test.
                if (checkModuleExists())
                {
                    // Lock file will be deleted in the finally-clause.
                    return moduleDir;
                }
                if (Directory.Exists(moduleDir))
                {
                    Directory.Delete(moduleDir, true);
                }
                break; // Proceed to downloading the module.
            }
            // IOException (including FileNotFoundException) and
            // UnauthorizedAccessException are believed to be permanent problems
            // with the module_dir that justify failing the download; they pass
            // through the filter untouched. All other errors are retried.
            // NOTE(review): if tf_utils.atomic_write_string_to_file reports an
            // already existing lock file as an IOException, the wait-and-retry
            // path below is never taken - confirm.
            // TODO(b/144424849): Retrying an AlreadyExistsError from the atomic write
            // should be good enough, but see discussion about misc filesystem types.
            // TODO(b/144475403): How atomic is the overwrite=False check?
            catch (Exception ex) when (!(ex is IOException) && !(ex is UnauthorizedAccessException))
            {
            }

            // Wait for lock file to disappear.
            _wait_for_lock_to_disappear(handle, lockFile, lock_file_timeout_sec);
            // At this point we either deleted a lock or a lock got removed by the
            // owner or another process. Perform one more iteration of the
            // while-loop: we either return because the module now exists, obtain
            // the lock ourselves, or wait again for the lock to disappear.
        }

        // Lock file acquired.
        tf.Logger.Information($"Downloading TF-Hub Module '{handle}'...");
        Directory.CreateDirectory(tmpDir);
        await downloadFn(handle, tmpDir);

        // Write module descriptor to capture information about which module was
        // downloaded by whom and when. The file is stored at the same level as
        // the directory in order to keep the content of the module directory
        // exactly as it was defined by the module publisher.
        //
        // Note: The descriptor is written purely to help the end-user to identify
        // which directory belongs to which module. The descriptor is not part of
        // the module caching protocol and no code in this file reads its content.
        _write_module_descriptor_file(handle, moduleDir);

        try
        {
            Directory.Move(tmpDir, moduleDir);
            Console.WriteLine($"Downloaded TF-Hub Module '{handle}'.");
        }
        catch (IOException e)
        {
            Console.WriteLine(e.Message);
            Console.WriteLine($"Failed to move {tmpDir} to {moduleDir}");
            // The temp directory is removed by the cleanup in the finally-clause.
        }
    }
    finally
    {
        // Temp directory is owned by the current process, remove it. After a
        // successful move (or before it was ever created) it does not exist.
        try
        {
            Directory.Delete(tmpDir, true);
        }
        catch (DirectoryNotFoundException)
        {
        }

        // Remove the lock file only when it exists and is owned by this call.
        // Best effort: failing to read or delete it must not hide the outcome
        // of the download itself.
        try
        {
            var contents = File.ReadAllText(lockFile);
            if (contents == lockContents)
            {
                File.Delete(lockFile);
            }
        }
        catch (Exception)
        {
        }
    }

    return moduleDir;
}
}
/// <summary>
/// Contract implemented by the handle resolvers in this file.
/// </summary>
internal interface IResolver
{
    /// <summary>
    /// Reports whether the resolver accepts the given handle
    /// (presumably checked before <see cref="Call"/> is invoked - confirm with callers).
    /// </summary>
    bool IsSupported(string handle);

    /// <summary>Resolves the handle and returns the result as a string.</summary>
    string Call(string handle);
}
/// <summary>
/// Resolver for handles that are paths in the local file system.
/// </summary>
internal class PathResolver : IResolver
{
    /// <summary>Always true: this resolver accepts every handle.</summary>
    public bool IsSupported(string handle) => true;

    /// <summary>
    /// Returns <paramref name="handle"/> unchanged when it names an existing
    /// file or directory.
    /// </summary>
    /// <exception cref="IOException">Neither a file nor a directory exists at the handle.</exception>
    public string Call(string handle)
    {
        bool existsOnDisk = File.Exists(handle) || Directory.Exists(handle);
        if (existsOnDisk)
        {
            return handle;
        }

        throw new IOException($"{handle} does not exist in file system.");
    }
}
/// <summary>
/// Base class for resolvers that fetch content over HTTP(S). Paths that are
/// not HTTP(S) URLs are read from the local file system instead.
/// </summary>
public abstract class HttpResolverBase : IResolver
{
    // Environment variable, and the exact value it must hold, that turns off
    // server certificate validation.
    private const string DisableCertValidationEnvVar = "TFHUB_DISABLE_CERT_VALIDATION";
    private const string DisableCertValidationValue = "true";

    // Client used while no custom SSL context is configured.
    private readonly HttpClient httpClient;
    // Client built for the current custom SSL context. Created on first use
    // and dropped whenever SetUrlContext changes the context, so requests do
    // not allocate a new client (and connection pool) each time.
    private HttpClient customClient;
    private SslProtocol sslProtocol;
    private RemoteCertificateValidationCallback certificateValidator;

    protected HttpResolverBase()
    {
        httpClient = new HttpClient();
        _maybe_disable_cert_validation();
    }

    public abstract string Call(string handle);
    public abstract bool IsSupported(string handle);

    /// <summary>
    /// Opens a local file for reading.
    /// </summary>
    /// <returns>
    /// The opened stream, or null when the file could not be opened (the
    /// failure is written to the console).
    /// </returns>
    protected Task<Stream> GetLocalFileStreamAsync(string filePath)
    {
        try
        {
            Stream fs = new FileStream(filePath, FileMode.Open, FileAccess.Read);
            return Task.FromResult(fs);
        }
        catch (Exception ex)
        {
            Console.WriteLine($"Failed to read file stream: {ex.Message}");
            return Task.FromResult<Stream>(null);
        }
    }

    /// <summary>
    /// Opens a stream for <paramref name="filePath"/>, which may be an HTTP(S)
    /// URL or a local path.
    /// </summary>
    /// <returns>
    /// The content stream, or null when the local file could not be opened or
    /// the server answered with a non-success status code.
    /// </returns>
    protected async Task<Stream> GetFileStreamAsync(string filePath)
    {
        if (!is_http_protocol(filePath))
        {
            // If filePath is not an HTTP(S) URL, delegate to a file resolver.
            return await GetLocalFileStreamAsync(filePath);
        }

        var request = new HttpRequestMessage(HttpMethod.Get, filePath);
        var response = await _call_urlopen(request);
        if (response.IsSuccessStatusCode)
        {
            // The response is not disposed here: the returned stream belongs
            // to its content.
            return await response.Content.ReadAsStreamAsync();
        }

        Console.WriteLine($"Failed to fetch file stream: {response.StatusCode} - {response.ReasonPhrase}");
        // Nothing is handed to the caller, so release the response right away.
        response.Dispose();
        return null;
    }

    /// <summary>
    /// Configures the SSL protocol and the optional certificate validator used
    /// for subsequent requests.
    /// </summary>
    protected void SetUrlContext(SslProtocol protocol, RemoteCertificateValidationCallback validator)
    {
        sslProtocol = protocol;
        certificateValidator = validator;
        // Force the next request to build a client from the new settings. The
        // previous client is not disposed because responses it produced may
        // still be in use by callers.
        customClient = null;
    }

    /// <summary>
    /// Returns <paramref name="handle"/> with the given key/value pair added
    /// to its query string. The result is rebuilt from scheme, host, port,
    /// path and query only.
    /// </summary>
    /// <param name="handle">Absolute URL.</param>
    /// <param name="formatQuery">Query parameter as (name, value).</param>
    public static string append_format_query(string handle, (string, string) formatQuery)
    {
        var parsed = new Uri(handle);
        var queryBuilder = HttpUtility.ParseQueryString(parsed.Query);
        queryBuilder.Add(formatQuery.Item1, formatQuery.Item2);
        parsed = new UriBuilder(parsed.Scheme, parsed.Host, parsed.Port, parsed.AbsolutePath,
            "?" + queryBuilder.ToString()).Uri;
        return parsed.ToString();
    }

    /// <summary>True when the handle starts with "http://" or "https://".</summary>
    protected bool is_http_protocol(string handle)
    {
        // URL schemes are not linguistic text: compare ordinally so the result
        // does not depend on the current culture.
        return handle.StartsWith("http://", StringComparison.Ordinal)
            || handle.StartsWith("https://", StringComparison.Ordinal);
    }

    /// <summary>
    /// Sends the request, using the client for the custom SSL context when one
    /// has been set through <see cref="SetUrlContext"/> and the default client
    /// otherwise.
    /// </summary>
    protected async Task<HttpResponseMessage> _call_urlopen(HttpRequestMessage request)
    {
        if (sslProtocol == null)
        {
            return await httpClient.SendAsync(request);
        }

        if (customClient == null)
        {
            customClient = _build_custom_client();
        }
        return await customClient.SendAsync(request);
    }

    // Builds a client whose handler uses the configured SSL protocol and, when
    // present, the configured certificate validator.
    private HttpClient _build_custom_client()
    {
        var handler = new HttpClientHandler()
        {
            SslProtocols = sslProtocol.AsEnum(),
        };
        var validator = certificateValidator;
        if (validator != null)
        {
            handler.ServerCertificateCustomValidationCallback = (x, y, z, w) =>
            {
                return validator(x, y, z, w);
            };
        }
        return new HttpClient(handler);
    }

    /// <summary>
    /// Disables server certificate validation when the
    /// TFHUB_DISABLE_CERT_VALIDATION environment variable equals "true".
    /// </summary>
    protected void _maybe_disable_cert_validation()
    {
        if (Environment.GetEnvironmentVariable(DisableCertValidationEnvVar) == DisableCertValidationValue)
        {
            // NOTE(review): on .NET Core / .NET 5+ HttpClient does not consult
            // ServicePointManager, so this setting may only take effect on
            // .NET Framework - confirm against the target frameworks.
            ServicePointManager.ServerCertificateValidationCallback = (_, _, _, _) => true;
            Console.WriteLine("Disabled certificate validation for resolving handles.");
        }
    }
}
/// <summary>
/// Named SSL/TLS protocol versions. Instances can only be obtained through
/// the static members; <see cref="AsEnum"/> maps an instance to the matching
/// <see cref="SslProtocols"/> value.
/// </summary>
public class SslProtocol
{
    public static readonly SslProtocol Tls = new SslProtocol("TLS");
    public static readonly SslProtocol Tls11 = new SslProtocol("TLS 1.1");
    public static readonly SslProtocol Tls12 = new SslProtocol("TLS 1.2");

    // Textual name of the protocol, e.g. "TLS 1.2".
    private readonly string protocolString;

    private SslProtocol(string protocolString)
    {
        this.protocolString = protocolString;
    }

    /// <summary>Returns the <see cref="SslProtocols"/> value for this protocol.</summary>
    /// <exception cref="ArgumentException">The protocol name is not one of the known names.</exception>
    public SslProtocols AsEnum()
    {
        var normalized = protocolString.ToUpper();
        if (normalized == "TLS")
        {
            return SslProtocols.Tls;
        }
        if (normalized == "TLS 1.1")
        {
            return SslProtocols.Tls11;
        }
        if (normalized == "TLS 1.2")
        {
            return SslProtocols.Tls12;
        }
        throw new ArgumentException($"Unknown SSL/TLS protocol: {protocolString}");
    }
}
}